Переглянути джерело

Support for Enum in HLSL 2017 (#309)

This change is to support enum types in HLSL 2017. Enum/enum class with fixed underlying type to match C+11 standard. Enum should only be enabled with -HV 2017 option on dxc.
Young Kim 8 роки тому
батько
коміт
4857166d1e

+ 2 - 0
tools/clang/include/clang/Basic/DiagnosticParseKinds.td

@@ -1022,6 +1022,8 @@ def warn_hlsl_effect_state_block : Warning <
 def warn_hlsl_effect_technique : Warning <
   "effect technique ignored - effect syntax is deprecated">,
   InGroup< HLSLEffectsSyntax >;
+def err_hlsl_enum : Error<
+  "enum is unsupported in HLSL before 2017">;
 
 // OpenMP support.
 def warn_pragma_omp_ignored : Warning<

+ 1 - 1
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -825,7 +825,7 @@ unsigned CGMSHLSLRuntime::ConstructStructAnnotation(DxilStructAnnotation *annota
 }
 
 static bool IsElementInputOutputType(QualType Ty) {
-  return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty);
+  return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty) || Ty->isEnumeralType();
 }
 
 // Return the size for constant buffer of each decl.

+ 3 - 3
tools/clang/lib/Frontend/CompilerInvocation.cpp

@@ -1722,9 +1722,9 @@ static void ParseLangArgs(LangOptions &Opts, ArgList &Args, InputKind IK,
 #else
   StringRef ver = Args.getLastArgValue(OPT_hlsl_version);
   Opts.HLSL2015 = Opts.HLSL2016 = Opts.HLSL2017 = false;
-  if (ver.empty() || ver == "2016") { Opts.HLSL2016 = true; }   // Default to 2016
-  else if           (ver == "2015") { Opts.HLSL2015 = true; }
-  else if           (ver == "2017") { Opts.HLSL2017 = true; }
+  if (ver.empty() || ver == "2016" || ver == "-2016") { Opts.HLSL2016 = true; }   // Default to 2016
+  else if           (ver == "2015" || ver == "-2015") { Opts.HLSL2015 = true; }
+  else if           (ver == "2017" || ver == "-2017") { Opts.HLSL2017 = true; }
   else {
     Diags.Report(diag::err_drv_invalid_value)
       << Args.getLastArg(OPT_hlsl_version)->getAsString(Args)

+ 9 - 8
tools/clang/lib/Parse/ParseDecl.cpp

@@ -4402,9 +4402,8 @@ void Parser::ParseEnumSpecifier(SourceLocation StartLoc, DeclSpec &DS,
                                 const ParsedTemplateInfo &TemplateInfo,
                                 AccessSpecifier AS, DeclSpecContext DSC) {
   // HLSL Change Starts
-  if (getLangOpts().HLSL) {
-    Diag(Tok, diag::err_hlsl_unsupported_construct) << "enum";
-
+  if (getLangOpts().HLSL && !getLangOpts().HLSL2017) {
+    Diag(Tok, diag::err_hlsl_enum);
     // Skip the rest of this declarator, up until the comma or semicolon.
     SkipUntil(tok::comma, StopAtSemi);
     return;
@@ -4423,15 +4422,17 @@ void Parser::ParseEnumSpecifier(SourceLocation StartLoc, DeclSpec &DS,
   MaybeParseGNUAttributes(attrs);
   MaybeParseCXX11Attributes(attrs);
   MaybeParseMicrosoftDeclSpecs(attrs);
-  assert(!getLangOpts().HLSL); // HLSL Change: in lieu of MaybeParseHLSLAttributes - enums not allowed
+  MaybeParseHLSLAttributes(attrs);
 
   SourceLocation ScopedEnumKWLoc;
   bool IsScopedUsingClassTag = false;
 
   // In C++11, recognize 'enum class' and 'enum struct'.
   if (Tok.isOneOf(tok::kw_class, tok::kw_struct)) {
-    Diag(Tok, getLangOpts().CPlusPlus11 ? diag::warn_cxx98_compat_scoped_enum
-                                        : diag::ext_scoped_enum);
+    // HLSL Change: Supress C++11 warning
+    if (!getLangOpts().HLSL)
+      Diag(Tok, getLangOpts().CPlusPlus11 ? diag::warn_cxx98_compat_scoped_enum
+                                          : diag::ext_scoped_enum);
     IsScopedUsingClassTag = Tok.is(tok::kw_class);
     ScopedEnumKWLoc = ConsumeToken();
 
@@ -4461,7 +4462,7 @@ void Parser::ParseEnumSpecifier(SourceLocation StartLoc, DeclSpec &DS,
 
   bool AllowFixedUnderlyingType = AllowDeclaration &&
     (getLangOpts().CPlusPlus11 || getLangOpts().MicrosoftExt ||
-     getLangOpts().ObjC2);
+     getLangOpts().ObjC2 || getLangOpts().HLSL2017);
 
   CXXScopeSpec &SS = DS.getTypeSpecScope();
   if (getLangOpts().CPlusPlus) {
@@ -4805,7 +4806,7 @@ void Parser::ParseEnumBody(SourceLocation StartLoc, Decl *EnumDecl) {
             << 1 /*enumerator*/;
       ParseCXX11Attributes(attrs);
     }
-    assert(!getLangOpts().HLSL); // HLSL Change: in lieu of MaybeParseHLSLAttributes - enums not allowed
+    MaybeParseHLSLAttributes(attrs);
 
     SourceLocation EqualLoc;
     ExprResult AssignedVal;

+ 3 - 1
tools/clang/lib/Sema/SemaCXXScopeSpec.cpp

@@ -617,7 +617,9 @@ bool Sema::BuildCXXNestedNameSpecifier(Scope *S,
   bool AcceptSpec = isAcceptableNestedNameSpecifier(SD, &IsExtension);
   if (!AcceptSpec && IsExtension) {
     AcceptSpec = true;
-    Diag(IdentifierLoc, diag::ext_nested_name_spec_is_enum);
+    // HLSL Change: Suppress c++11 extension warnings for nested name specifier in HLSL2017
+    if (!getLangOpts().HLSL2017)
+        Diag(IdentifierLoc, diag::ext_nested_name_spec_is_enum);
   }
   if (AcceptSpec) {
     if (!ObjectType.isNull() && !ObjectTypeSearchedInScope &&

+ 49 - 3
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -60,6 +60,7 @@ enum ArBasicKind {
   AR_BASIC_MIN12INT,
   AR_BASIC_MIN16INT,
   AR_BASIC_MIN16UINT,
+  AR_BASIC_ENUM,
 
   AR_BASIC_COUNT,
 
@@ -77,6 +78,7 @@ enum ArBasicKind {
   //
 
   AR_BASIC_POINTER,
+  AR_BASIC_ENUM_CLASS,
 
   AR_OBJECT_NULL,
   AR_OBJECT_STRING,
@@ -247,6 +249,7 @@ enum ArBasicKind {
 #define BPROP_PRIMITIVE         0x00100000  // Whether the type is a primitive scalar type.
 #define BPROP_MIN_PRECISION     0x00200000  // Whether the type is qualified with a minimum precision.
 #define BPROP_ROVBUFFER         0x00400000  // Whether the type is a ROV object.
+#define BPROP_ENUM              0x00800000  // Whether the type is a enum
 
 #define GET_BPROP_PRIM_KIND(_Props) \
     ((_Props) & (BPROP_BOOLEAN | BPROP_INTEGER | BPROP_FLOATING))
@@ -289,6 +292,9 @@ enum ArBasicKind {
 #define IS_BPROP_UNSIGNABLE(_Props) \
     (IS_BPROP_AINT(_Props) && GET_BPROP_BITS(_Props) != BPROP_BITS12)
 
+#define IS_BPROP_ENUM(_Props) \
+    (((_Props) & BPROP_ENUM) != 0)
+
 const UINT g_uBasicKindProps[] =
 {
   BPROP_PRIMITIVE | BPROP_BOOLEAN | BPROP_INTEGER | BPROP_NUMERIC | BPROP_BITS0,  // AR_BASIC_BOOL
@@ -316,6 +322,7 @@ const UINT g_uBasicKindProps[] =
   BPROP_PRIMITIVE | BPROP_NUMERIC | BPROP_INTEGER | BPROP_BITS16 | BPROP_MIN_PRECISION,   // AR_BASIC_MIN16INT
   BPROP_PRIMITIVE | BPROP_NUMERIC | BPROP_INTEGER | BPROP_UNSIGNED | BPROP_BITS16 | BPROP_MIN_PRECISION,  // AR_BASIC_MIN16UINT
 
+  BPROP_ENUM | BPROP_NUMERIC | BPROP_INTEGER, // AR_BASIC_ENUM
   BPROP_OTHER,  // AR_BASIC_COUNT
 
   //
@@ -332,10 +339,12 @@ const UINT g_uBasicKindProps[] =
   //
 
   BPROP_POINTER,  // AR_BASIC_POINTER
+  BPROP_ENUM, // AR_BASIC_ENUM_CLASS
 
   BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_NULL
   BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_STRING
 
+
   // BPROP_OBJECT | BPROP_TEXTURE, // AR_OBJECT_TEXTURE
   BPROP_OBJECT | BPROP_TEXTURE, // AR_OBJECT_TEXTURE1D
   BPROP_OBJECT | BPROP_TEXTURE, // AR_OBJECT_TEXTURE1D_ARRAY
@@ -461,6 +470,9 @@ C_ASSERT(ARRAYSIZE(g_uBasicKindProps) == AR_BASIC_MAXIMUM_COUNT);
 #define IS_BASIC_UNSIGNABLE(_Kind) \
     IS_BPROP_UNSIGNABLE(GetBasicKindProps(_Kind))
 
+#define IS_BASIC_ENUM(_Kind) \
+    IS_BPROP_ENUM(GetBasicKindProps(_Kind))
+
 #define BITWISE_ENUM_OPS(_Type)                                         \
 inline _Type operator|(_Type F1, _Type F2)                              \
 {                                                                       \
@@ -1290,12 +1302,15 @@ const char* g_ArBasicTypeNames[] =
   "int", "uint", "long", "ulong",
   "min10float", "min16float",
   "min12int", "min16int", "min16uint",
+  "enum",
 
   "<count>",
   "<none>",
   "<unknown>",
   "<nocast>",
   "<pointer>",
+  "enum class",
+
   "null",
   "string",
   // "texture",
@@ -3120,6 +3135,7 @@ public:
     }
 
     if (type->isBuiltinType()) return AR_TOBJ_BASIC;
+    if (type->isEnumeralType()) return AR_TOBJ_BASIC;
 
     return AR_TOBJ_INVALID;
   }
@@ -3210,7 +3226,11 @@ public:
       case BuiltinType::LitInt: return AR_BASIC_LITERAL_INT;
       }
     }
-
+    if (const EnumType *ET = dyn_cast<EnumType>(type)) {
+        if (ET->getDecl()->isScopedUsingClassTag())
+            return AR_BASIC_ENUM_CLASS;
+        return AR_BASIC_ENUM;
+    }
     return AR_BASIC_UNKNOWN;
   }
 
@@ -3357,7 +3377,7 @@ public:
     case AR_OBJECT_APPEND_STRUCTURED_BUFFER:
     case AR_OBJECT_CONSUME_STRUCTURED_BUFFER:
     case AR_OBJECT_WAVE:
-    {
+{
         const ArBasicKind* match = std::find(g_ArBasicKindsAsTypes, &g_ArBasicKindsAsTypes[_countof(g_ArBasicKindsAsTypes)], kind);
         DXASSERT(match != &g_ArBasicKindsAsTypes[_countof(g_ArBasicKindsAsTypes)], "otherwise can't find constant in basic kinds");
         size_t index = match - g_ArBasicKindsAsTypes;
@@ -5218,6 +5238,10 @@ static bool UnaryOperatorKindDisallowsBool(UnaryOperatorKind Opc)
     Opc == UnaryOperatorKind::UO_PostDec || Opc == UnaryOperatorKind::UO_PostInc;
 }
 
+static bool IsIncrementOp(UnaryOperatorKind Opc) {
+  return Opc == UnaryOperatorKind::UO_PreInc || Opc == UnaryOperatorKind::UO_PostInc;
+}
+
 /// <summary>
 /// Checks whether the specified AR_TOBJ* value is a primitive or aggregate of primitive elements
 /// (as opposed to a built-in object like a sampler or texture, or a void type).
@@ -7326,7 +7350,13 @@ bool HLSLExternalSource::CanConvert(
     {
       Remarks |= TYPE_CONVERSION_ELT_TRUNCATION;
     }
-
+    // enum -> enum not allowed
+    if ((SourceInfo.EltKind == AR_BASIC_ENUM &&
+        TargetInfo.EltKind == AR_BASIC_ENUM) ||
+        SourceInfo.EltKind == AR_BASIC_ENUM_CLASS ||
+        TargetInfo.EltKind == AR_BASIC_ENUM_CLASS) {
+      return false;
+    }
     if (SourceInfo.EltKind != TargetInfo.EltKind)
     {
       if (TargetInfo.EltKind == AR_BASIC_UNKNOWN ||
@@ -7338,6 +7368,16 @@ bool HLSLExternalSource::CanConvert(
       {
         ComponentConversion = ICK_Boolean_Conversion;
       }
+      else if (IS_BASIC_ENUM(TargetInfo.EltKind))
+      {
+        // conversion to enum type not allowed
+        return false;
+      }
+      else if (IS_BASIC_ENUM(SourceInfo.EltKind))
+      {
+        // enum -> int/float
+        ComponentConversion = ICK_Integral_Conversion;
+      }
       else
       {
         bool targetIsInt = IS_BASIC_AINT(TargetInfo.EltKind);
@@ -7824,6 +7864,12 @@ QualType HLSLExternalSource::CheckUnaryOpForHLSL(
   ArBasicKind elementKind = GetTypeElementKind(expr->getType());
 
   if (UnaryOperatorKindRequiresModifiableValue(Opc)) {
+    if (elementKind == AR_BASIC_ENUM) {
+      bool isInc = IsIncrementOp(Opc);
+      m_sema->Diag(OpLoc, diag::err_increment_decrement_enum) << isInc << expr->getType();
+      return QualType();
+    }
+
     extern bool CheckForModifiableLvalue(Expr *E, SourceLocation Loc, Sema &S);
     if (CheckForModifiableLvalue(expr, OpLoc, *m_sema))
       return QualType();

+ 1 - 1
tools/clang/lib/Sema/SemaOverload.cpp

@@ -5132,7 +5132,7 @@ static ExprResult CheckConvertedConstantExpression(Sema &S, Expr *From,
                                                    QualType T, APValue &Value,
                                                    Sema::CCEKind CCE,
                                                    bool RequireInt) {
-  assert(S.getLangOpts().CPlusPlus11 &&
+  assert(S.getLangOpts().CPlusPlus11 || S.getLangOpts().HLSL2017 &&
          "converted constant expression outside C++11");
 
   if (checkPlaceholderForOverload(S, From))

+ 2 - 2
tools/clang/lib/Sema/SemaStmt.cpp

@@ -392,7 +392,7 @@ Sema::ActOnCaseStmt(SourceLocation CaseLoc, Expr *LHSVal,
     return StmtError();
   LHSVal = LHS.get();
 
-  if (!getLangOpts().CPlusPlus11) {
+  if (!getLangOpts().CPlusPlus11 && !getLangOpts().HLSL2017) {
     // C99 6.8.4.2p3: The expression shall be an integer constant.
     // However, GCC allows any evaluatable integer expression.
     if (!LHSVal->isTypeDependent() && !LHSVal->isValueDependent()) {
@@ -860,7 +860,7 @@ Sema::ActOnFinishSwitchStmt(SourceLocation SwitchLoc, Stmt *Switch,
 
       llvm::APSInt LoVal;
 
-      if (getLangOpts().CPlusPlus11) {
+      if (getLangOpts().CPlusPlus11 || getLangOpts().HLSL2017) {
         // C++11 [stmt.switch]p2: the constant-expression shall be a converted
         // constant expression of the promoted type of the switch condition.
         ExprResult ConvLo =

+ 32 - 0
tools/clang/test/CodeGenHLSL/enum1.hlsl

@@ -0,0 +1,32 @@
+// RUN: %dxc -E main -T ps_6_1 -HV 2017 %s | FileCheck %s
+
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 0, i32 1)
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 1, i32 2)
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 2, i32 3)
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 3, i32 4)
+
+enum class MyEnum {
+    FIRST,
+    SECOND,
+    THIRD,
+    FOURTH,
+};
+
+int f(MyEnum v) {
+    switch (v) {
+        case MyEnum::FIRST:
+            return 1;
+        case MyEnum::SECOND:
+            return 2;
+        case MyEnum::THIRD:
+            return 3;
+        case MyEnum::FOURTH:
+            return 4;
+        default:
+            return 0;
+    }
+}
+
+int4 main() : SV_Target {
+    return int4(f(MyEnum::FIRST), f(MyEnum::SECOND), f(MyEnum::THIRD), f(MyEnum::FOURTH));
+}

+ 17 - 0
tools/clang/test/CodeGenHLSL/enum2.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -E main -T ps_6_0 -HV 2017 %s | FileCheck %s
+
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 0, i32 10)
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 1, i32 -2)
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 2, i32 48)
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 3, i32 -25)
+
+enum Vertex : int {
+    FIRST = 10,
+    SECOND = -2,
+    THIRD = 48,
+    FOURTH = -25,
+};
+
+int4 main(float4 col : COLOR) : SV_Target {
+    return float4(FIRST, SECOND, THIRD, FOURTH);
+}

+ 13 - 0
tools/clang/test/CodeGenHLSL/enum3.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -E main -T ps_6_1 -HV 2017 %s | FileCheck %s
+
+// CHECK: dx.op.attributeAtVertex
+
+enum Vertex {
+    FIRST,
+    SECOND,
+    THIRD
+};
+
+int4 main(nointerpolation float4 col : COLOR) : SV_Target {
+    return GetAttributeAtVertex(col, Vertex::THIRD);
+}

+ 24 - 0
tools/clang/test/CodeGenHLSL/enum4.hlsl

@@ -0,0 +1,24 @@
+// RUN: %dxc -E main -T ps_6_1 -HV 2017 %s | FileCheck %s
+
+// CHECK: call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 0, i32 1)
+
+enum Vertex : int {
+    FIRST,
+    SECOND,
+    THIRD
+};
+
+int4 getValueInt(int i) {
+    switch (i) {
+        case 0:
+            return int4(1,1,1,1);
+        case 1:
+            return int4(2,2,2,2);
+        case 2:
+            return int4(3,3,3,3);
+    }
+}
+
+int4 main(float4 col : COLOR) : SV_Target {
+    return getValueInt(Vertex::FIRST);
+}

+ 13 - 0
tools/clang/test/CodeGenHLSL/enum5.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -E main -T ps_6_1 -HV 2017 %s | FileCheck %s
+
+// CHECK: fadd
+
+enum Vertex {
+    FIRST,
+    SECOND,
+    THIRD
+};
+
+float4 main(float4 col : COLOR) : SV_Target {
+    return !Vertex::FIRST + col;
+}

+ 3 - 3
tools/clang/test/HLSL/cpp-errors.hlsl

@@ -251,8 +251,8 @@ void vla(int size) {
   return n[0];
 }
 
-enum MyEnum  { MyEnum_MyVal1, MyEnum_MyVal2 }; // expected-error {{enum is unsupported in HLSL}} expected-warning {{declaration does not declare anything}}
-enum class MyEnumWithClass { MyEnumWithClass_MyVal1, MyEnumWithClass_MyVal2 }; // expected-error {{enum is unsupported in HLSL}} expected-warning {{declaration does not declare anything}}
+enum MyEnum  { MyEnum_MyVal1, MyEnum_MyVal2 }; // expected-error {{enum is unsupported in HLSL before 2017}} expected-warning {{declaration does not declare anything}}
+enum class MyEnumWithClass { MyEnumWithClass_MyVal1, MyEnumWithClass_MyVal2 }; // expected-error {{enum is unsupported in HLSL before 2017}} expected-warning {{declaration does not declare anything}}
 
 float4 fn_with_semantic() : SV_Target0{
   return 0;
@@ -647,4 +647,4 @@ float4 plain(float4 param4 /* : FOO */) /*: FOO */{
   const j; // expected-error {{HLSL requires a type specifier for all declarations}}
   long long ll; // expected-error {{'long' is a reserved keyword in HLSL}} expected-error {{'long' is a reserved keyword in HLSL}} expected-error {{HLSL requires a type specifier for all declarations}}
   return is_supported();
-}
+}

+ 202 - 0
tools/clang/test/HLSL/enums.hlsl

@@ -0,0 +1,202 @@
+// RUN: %clang_cc1 -HV -2017 -fsyntax-only -ffreestanding -verify %s
+
+enum MyEnum {
+    ZERO,
+    ONE,
+    TWO,
+    THREE,
+    FOUR,
+    TEN = 10,
+};
+
+enum class MyEnumClass {
+  ZEROC,
+  ONEC,
+  TWOC,
+  THREEC,
+  FOURC = 4,
+};
+
+enum MyEnumBool : bool {
+  ZEROB = true,
+};
+
+enum MyEnumInt : int {
+  ZEROI,
+  ONEI,
+  TWOI,
+  THREEI,
+  FOURI,
+  NEGONEI = -1,
+};
+
+enum MyEnumUInt : uint {
+  ZEROU,
+  ONEU,
+  TWOU,
+  THREEU,
+  FOURU,
+};
+
+enum MyEnumDWord : dword {
+  ZERODWORD,
+};
+
+enum MyEnum64 : uint64_t {
+  ZERO64,
+  ONE64,
+  TWO64,
+  THREE64,
+  FOUR64,
+};
+
+enum MyEnumMin16int : min16int {
+  ZEROMIN16INT,
+};
+
+enum MyEnumMin16uint : min16uint {
+  ZEROMIN16UINT,
+};
+
+enum MyEnumHalf : half {                                    /* expected-error {{non-integral type 'half' is an invalid underlying type}} */
+  ZEROH,
+};
+
+enum MyEnumFloat : float {                                  /* expected-error {{non-integral type 'float' is an invalid underlying type}} */
+  ZEROF,
+};
+
+enum MyEnumDouble : double {                                /* expected-error {{non-integral type 'double' is an invalid underlying type}} */
+  ZEROD,
+};
+
+enum MyEnumMin16Float : min16float {                        /* expected-error {{non-integral type 'min16float' is an invalid underlying type}} */
+  ZEROMIN16F,
+};
+
+enum MyEnumMin10Float : min10float {                        /* expected-error {{non-integral type 'min10float' is an invalid underlying type}} expected-warning {{min10float is promoted to min16float}} */
+  ZEROMIN10F,
+};
+
+int getValueFromMyEnum(MyEnum v) {                          /* expected-note {{candidate function not viable: no known conversion from 'MyEnum64' to 'MyEnum' for 1st argument}} expected-note {{candidate function not viable: no known conversion from 'MyEnumClass' to 'MyEnum' for 1st argument}} expected-note {{candidate function not viable: no known conversion from 'MyEnumInt' to 'MyEnum' for 1st argument}} expected-note {{candidate function not viable: no known conversion from 'MyEnumUInt' to 'MyEnum' for 1st argument}} expected-note {{candidate function not viable: no known conversion from 'literal int' to 'MyEnum' for 1st argument}} */
+  switch (v) {
+    case MyEnum::ZERO:
+      return 0;
+    case MyEnum::ONE:
+      return 1;
+    case MyEnum::TWO:
+      return 2;
+    case MyEnum::THREE:
+      return 3;
+    default:
+      return -1;
+  }
+}
+
+int getValueFromMyEnumClass(MyEnumClass v) {
+  switch (v) {
+    case MyEnumClass::ZEROC:
+      return 0;
+    case MyEnumClass::ONEC:
+      return 1;
+    case MyEnumClass::TWOC:
+      return 2;
+    case MyEnumClass::THREEC:
+      return 3;
+    default:
+      return -1;
+  }
+}
+
+int getValueFromInt(int i) {                                /* expected-note {{candidate function not viable: no known conversion from 'MyEnumClass' to 'int' for 1st argument}} */
+  switch (i) {
+    case 0:
+      return 0;
+    case 1:
+      return 1;
+    case 2:
+      return 2;
+    default:
+      return -1;
+  }
+}
+
+int4 main() : SV_Target {
+    int v0 = getValueFromInt(ZERO);
+    int v1 = getValueFromInt(MyEnumClass::ONEC); /* expected-error {{no matching function for call to 'getValueFromInt'}} */
+    int v2 = getValueFromInt(TWOI);
+    int v3 = getValueFromInt(THREEU);
+    int v4 = getValueFromInt(FOUR64);
+
+    int n0 = getValueFromMyEnum(ZERO);
+    int n1 = getValueFromMyEnum(MyEnumClass::ONEC); /* expected-error {{no matching function for call to 'getValueFromMyEnum'}} */
+    int n2 = getValueFromMyEnum(TWOI);              /* expected-error {{no matching function for call to 'getValueFromMyEnum'}} */
+    int n3 = getValueFromMyEnum(THREEU);            /* expected-error {{no matching function for call to 'getValueFromMyEnum'}} */
+    int n4 = getValueFromMyEnum(ZERO64);            /* expected-error {{no matching function for call to 'getValueFromMyEnum'}} */
+    int n5 = getValueFromMyEnum(2);                 /* expected-error {{no matching function for call to 'getValueFromMyEnum'}} */
+
+    int n6 = getValueFromMyEnumClass(MyEnumClass::ONEC);
+
+
+    MyEnum cast0 = (MyEnum) MyEnum::FOUR;
+    MyEnum cast1 = (MyEnum) MyEnumClass::THREEC;
+    MyEnum cast2 = (MyEnum) MyEnumInt::TWOI;
+    MyEnum cast3 = (MyEnum) MyEnum64::ONE64;
+    MyEnum cast4 = (MyEnum) MyEnumUInt::ZEROU;
+
+    MyEnum lst[4] = { ONE, TWO, TWO, THREE };
+
+    MyEnum unary0 = MyEnum::ZERO;
+    MyEnumClass unary1 = MyEnumClass::ZEROC;
+    unary0++;                                       /* expected-error {{cannot increment expression of enum type 'MyEnum'}} */
+    unary0--;                                       /* expected-error {{cannot decrement expression of enum type 'MyEnum'}} */
+    ++unary0;                                       /* expected-error {{cannot increment expression of enum type 'MyEnum'}} */
+    --unary0;                                       /* expected-error {{cannot decrement expression of enum type 'MyEnum'}} */
+    unary1++;                                       /* expected-error {{numeric type expected}} */
+    unary1--;                                       /* expected-error {{numeric type expected}} */
+    ++unary1;                                       /* expected-error {{numeric type expected}} */
+    --unary1;                                       /* expected-error {{numeric type expected}} */
+    
+    int unaryInt = !unary0;                                       
+    unaryInt = ~unary0;                                        
+    unaryInt = !unary1;                  /* expected-error {{numeric type expected}} */
+    unaryInt = ~unary1;                  /* expected-error {{int or unsigned int type required}} */
+
+    MyEnum castV = 1;                    /* expected-error {{cannot initialize a variable of type 'MyEnum' with an rvalue of type 'literal int'}} */
+    MyEnumInt castI = 10;                /* expected-error {{cannot initialize a variable of type 'MyEnumInt' with an rvalue of type 'literal int'}} */
+    MyEnumClass castC = 52;              /* expected-error {{cannot initialize a variable of type 'MyEnumClass' with an rvalue of type 'literal int'}} */
+    MyEnumUInt castU = 34;               /* expected-error {{cannot initialize a variable of type 'MyEnumUInt' with an rvalue of type 'literal int'}} */
+    MyEnum64 cast64 = 4037;              /* expected-error {{cannot initialize a variable of type 'MyEnum64' with an rvalue of type 'literal int'}} */
+
+    MyEnum MyEnum = MyEnum::ZERO;
+    MyEnumClass MyEnumClass = MyEnumClass::FOURC;
+    int i0 = MyEnum;
+    int i1 = MyEnum::ZERO;
+    int i2 = MyEnumClass;                 /* expected-error {{cannot initialize a variable of type 'int' with an lvalue of type 'MyEnumClass'}} */
+    int i3 = MyEnumClass::FOURC;          /* expected-error {{cannot initialize a variable of type 'int' with an rvalue of type 'MyEnumClass'}} */
+    float f0 = MyEnum;
+    float f1 = MyEnum::ZERO;
+    float f2 = MyEnumClass;               /* expected-error {{cannot initialize a variable of type 'float' with an lvalue of type 'MyEnumClass'}} */
+    float f3 = MyEnumClass::FOURC;        /* expected-error {{cannot initialize a variable of type 'float' with an rvalue of type 'MyEnumClass'}} */
+
+    int unaryD = THREE++;                /* expected-error {{cannot increment expression of enum type 'MyEnum'}} */
+    int unaryC = --MyEnumClass::FOURC;   /* expected-error {{expression is not assignable}} */
+    int unaryI = ++TWOI;                 /* expected-error {{cannot increment expression of enum type 'MyEnumInt'}} */
+    uint unaryU = ZEROU--;               /* expected-error {{cannot decrement expression of enum type 'MyEnumUInt'}} */
+    int unary64 = ++THREE64;             /* expected-error {{cannot increment expression of enum type 'MyEnum64'}} */
+
+
+    int Iadd = MyEnum::THREE - 48;
+    int IaddI = MyEnumInt::ZEROI + 3;
+    int IaddC = MyEnumClass::ONEC + 10; /* expected-error {{numeric type expected}} */
+    int IaddU = MyEnumUInt::TWOU + 15;
+    int Iadd64 = MyEnum64::THREE64 - 67;
+
+    float Fadd = MyEnum::ONE + 1.5f;
+    float FaddI = MyEnumInt::TWOI + 3.41f;
+    float FaddC = MyEnumClass::THREEC - 256.0f; /* expected-error {{numeric type expected}} */
+    float FaddU = MyEnumUInt::FOURU + 283.48f;
+    float Fadd64 = MyEnum64::ZERO64  - 8471.0f;
+
+    return 1;
+}

+ 3 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -2834,6 +2834,9 @@ public:
     }
     compiler.getLangOpts().RootSigMajor = 1;
     compiler.getLangOpts().RootSigMinor = rootSigMinor;
+    compiler.getLangOpts().HLSL2015 = Opts.HLSL2015;
+    compiler.getLangOpts().HLSL2016 = Opts.HLSL2016;
+    compiler.getLangOpts().HLSL2017 = Opts.HLSL2017;
 
     if (Opts.WarningAsError)
       compiler.getDiagnostics().setWarningsAsErrors(true);

+ 25 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -429,6 +429,11 @@ public:
   TEST_METHOD(CodeGenEliminateDynamicIndexing6)
   TEST_METHOD(CodeGenEmpty)
   TEST_METHOD(CodeGenEmptyStruct)
+  TEST_METHOD(CodeGenEnum1)
+  TEST_METHOD(CodeGenEnum2)
+  TEST_METHOD(CodeGenEnum3)
+  TEST_METHOD(CodeGenEnum4)
+  TEST_METHOD(CodeGenEnum5)
   TEST_METHOD(CodeGenEarlyDepthStencil)
   TEST_METHOD(CodeGenEval)
   TEST_METHOD(CodeGenEvalInvalid)
@@ -2464,6 +2469,26 @@ TEST_F(CompilerTest, CodeGenEmptyStruct) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\emptyStruct.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenEnum1) {
+    CodeGenTestCheck(L"..\\CodeGenHLSL\\enum1.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenEnum2) {
+    CodeGenTestCheck(L"..\\CodeGenHLSL\\enum2.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenEnum3) {
+    CodeGenTestCheck(L"..\\CodeGenHLSL\\enum3.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenEnum4) {
+    CodeGenTestCheck(L"..\\CodeGenHLSL\\enum4.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenEnum5) {
+    CodeGenTestCheck(L"..\\CodeGenHLSL\\enum5.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenEarlyDepthStencil) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\earlyDepthStencil.hlsl");
 }

+ 5 - 0
tools/clang/unittests/HLSL/VerifierTest.cpp

@@ -39,6 +39,7 @@ public:
   TEST_METHOD(RunConstAssign);
   TEST_METHOD(RunConstDefault);
   TEST_METHOD(RunCppErrors);
+  TEST_METHOD(RunEnums);
   TEST_METHOD(RunFunctions);
   TEST_METHOD(RunIndexingOperator);
   TEST_METHOD(RunIntrinsicExamples);
@@ -153,6 +154,10 @@ TEST_F(VerifierTest, RunCppErrors) {
   CheckVerifiesHLSL(L"cpp-errors.hlsl");
 }
 
+TEST_F(VerifierTest, RunEnums) {
+  CheckVerifiesHLSL(L"enums.hlsl");
+}
+
 TEST_F(VerifierTest, RunFunctions) {
   CheckVerifiesHLSL(L"functions.hlsl");
 }

+ 1 - 0
utils/hct/VerifierHelper.py

@@ -50,6 +50,7 @@ HlslBinDir = os.path.expandvars(r'${HLSL_BLD_DIR}\Debug\bin')
 VerifierTests = {
     'RunAttributes': "attributes.hlsl",
 #    'RunCppErrors': "cpp-errors.hlsl",             # This test doesn't work properly in HLSL (fxc mode)
+    'RunEnums' : "enums.hlsl",
     'RunIndexingOperator': "indexing-operator.hlsl",
     'RunIntrinsicExamples': "intrinsic-examples.hlsl",
     'RunMatrixAssignments': "matrix-assignments.hlsl",