Przeglądaj źródła

Support pragma pack_matrix. (#1623)

Xiang Li 6 lat temu
rodzic
commit
536240b3b8

+ 1 - 0
tools/clang/include/clang/Parse/Parser.h

@@ -169,6 +169,7 @@ class Parser : public CodeCompletionHandler {
   std::unique_ptr<PragmaHandler> LoopHintHandler;
   std::unique_ptr<PragmaHandler> UnrollHintHandler;
   std::unique_ptr<PragmaHandler> NoUnrollHintHandler;
+  std::unique_ptr<PragmaHandler> PackMatrixHandler; // HLSL Change -packmatrix.
 
   std::unique_ptr<CommentHandler> CommentSemaHandler;
 

+ 9 - 0
tools/clang/include/clang/Sema/Sema.h

@@ -333,6 +333,12 @@ public:
   LangOptions::PragmaMSPointersToMembersKind
       MSPointerToMemberRepresentationMethod;
 
+  // HLSL Change Begin - pragma pack_matrix.
+  // Add both row/col to identify the default case which no pragma.
+  bool PackMatrixRowMajorPragmaOn = false; // True when \#pragma pack_matrix(row_major) on.
+  bool PackMatrixColMajorPragmaOn = false; // True when \#pragma pack_matrix(column_major) on.
+  // HLSL Change End.
+
   enum PragmaVtorDispKind {
     PVDK_Push,          ///< #pragma vtordisp(push, mode)
     PVDK_Set,           ///< #pragma vtordisp(mode)
@@ -7554,6 +7560,9 @@ public:
                        SourceLocation LParenLoc,
                        SourceLocation RParenLoc);
 
+  /// ActOnPragmaPackMatrix - Called on well formed \#pragma pack_matrix(...).
+  void ActOnPragmaPackMatrix(bool bRowMajor, SourceLocation PragmaLoc);
+
   /// ActOnPragmaMSStruct - Called on well formed \#pragma ms_struct [on|off].
   void ActOnPragmaMSStruct(PragmaMSStructKind Kind);
 

+ 66 - 0
tools/clang/lib/Parse/ParsePragma.cpp

@@ -156,6 +156,15 @@ struct PragmaUnrollHintHandler : public PragmaHandler {
                     Token &FirstToken) override;
 };
 
+struct PragmaPackMatrixHandler : public PragmaHandler {
+  PragmaPackMatrixHandler(Sema &S) : PragmaHandler("pack_matrix"), Actions(S) {}
+  void HandlePragma(Preprocessor &PP, PragmaIntroducerKind Introducer,
+                    Token &FirstToken) override;
+
+private:
+  Sema &Actions;
+};
+
 }  // end namespace
 
 void Parser::initializePragmaHandlers() {
@@ -240,6 +249,12 @@ void Parser::initializePragmaHandlers() {
   NoUnrollHintHandler.reset(new PragmaUnrollHintHandler("nounroll"));
   PP.AddPragmaHandler(NoUnrollHintHandler.get());
   } // HLSL Change, matching HLSL check to remove pragma processing
+  else {
+    // HLSL Change Begin - packmatrix.
+    PackMatrixHandler.reset(new PragmaPackMatrixHandler(Actions));
+    PP.AddPragmaHandler(PackMatrixHandler.get());
+    // HLSL Change End.
+  }
 }
 
 void Parser::resetPragmaHandlers() {
@@ -311,6 +326,12 @@ void Parser::resetPragmaHandlers() {
   PP.RemovePragmaHandler(NoUnrollHintHandler.get());
   NoUnrollHintHandler.reset();
   } // HLSL Change - close conditional for skipping pragmas
+  else {
+    // HLSL Change Begin - packmatrix.
+    PP.RemovePragmaHandler(PackMatrixHandler.get());
+    PackMatrixHandler.reset();
+    // HLSL Change End.
+  }
 }
 
 /// \brief Handle the annotation token produced for #pragma unused(...)
@@ -2165,3 +2186,48 @@ void PragmaUnrollHintHandler::HandlePragma(Preprocessor &PP,
   PP.EnterTokenStream(TokenArray, 1, /*DisableMacroExpansion=*/false,
                       /*OwnsTokens=*/true);
 }
+
+// HLSL Change Begin - pack_matrix
+/// \brief Handle the pack_matrix pragmas.
+///  #pragma pack_matrix(row_major)
+///  #pragma pack_matrix(column_major)
+///
+void PragmaPackMatrixHandler::HandlePragma(Preprocessor &PP,
+                                           PragmaIntroducerKind Introducer,
+                                           Token &Tok) {
+  assert(PP.getLangOpts().HLSL && "only supported in HLSL");
+  Token PragmaName = Tok;
+  PP.Lex(Tok);
+  if (!Tok.is(tok::l_paren)) {
+    PP.Diag(Tok, diag::err_expected) << tok::l_brace;
+    return;
+  }
+
+  PP.Lex(Tok);
+  Token PragmaArg = Tok;
+  bool bRowMajor = false;
+  if (Tok.is(tok::kw_row_major)) {
+    bRowMajor = true;
+  }
+  else if (Tok.isNot(tok::kw_column_major)) {
+    PP.Diag(Tok.getLocation(), diag::err_pragma_invalid_keyword);
+    return;
+  }
+  // Make sure pragma finish correctly.
+  PP.Lex(Tok);
+  if (Tok.isNot(tok::r_paren)) {
+    PP.Diag(Tok, diag::err_expected) << tok::r_brace;
+    return;
+  }
+  PP.Lex(Tok);
+  if (Tok.isNot(tok::eod)) {
+    PP.Diag(Tok.getLocation(), diag::warn_pragma_extra_tokens_at_eol);
+    return;
+  }
+  // Note: to make things easy, pack_matrix will modify ast type directly in
+  // Sema::TransferUnusualAttributes.
+  // Another solution is create ast node for pack_matrix, and take care it at
+  // clang codegen.
+  Actions.ActOnPragmaPackMatrix(bRowMajor, PragmaArg.getLocation());
+}
+// HLSL Change End.

+ 10 - 0
tools/clang/lib/Sema/SemaAttr.cpp

@@ -265,6 +265,16 @@ void Sema::ActOnPragmaPack(PragmaPackKind Kind, IdentifierInfo *Name,
   }
 }
 
+void Sema::ActOnPragmaPackMatrix(bool bRowMajor, SourceLocation PragmaLoc) {
+  if (bRowMajor) {
+    PackMatrixRowMajorPragmaOn = true;
+    PackMatrixColMajorPragmaOn = false;
+  } else {
+    PackMatrixRowMajorPragmaOn = false;
+    PackMatrixColMajorPragmaOn = true;
+  }
+}
+
 void Sema::ActOnPragmaMSStruct(PragmaMSStructKind Kind) { 
   MSStructPragmaOn = (Kind == PMSST_ON);
 }

+ 50 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -11061,6 +11061,56 @@ void Sema::TransferUnusualAttributes(Declarator &D, NamedDecl *NewDecl) {
         D.UnusualAnnotations.size()));
     D.UnusualAnnotations.clear();
   }
+  // pragma pack_matrix.
+  // Do this for struct member also.
+  if (ValueDecl *VD = dyn_cast<ValueDecl>(NewDecl)) {
+    QualType Ty = VD->getType();
+    QualType EltTy = Ty;
+    while (EltTy->isArrayType()) {
+      EltTy = EltTy->getAsArrayTypeUnsafe()->getElementType();
+    }
+    if (hlsl::IsHLSLMatType(EltTy)) {
+      bool bRowMajor = false;
+      if (!hlsl::HasHLSLMatOrientation(EltTy, &bRowMajor)) {
+        if (PackMatrixColMajorPragmaOn || PackMatrixRowMajorPragmaOn) {
+          // Add major.
+          QualType NewEltTy = Context.getAttributedType(
+              PackMatrixRowMajorPragmaOn
+                  ? AttributedType::attr_hlsl_row_major
+                  : AttributedType::attr_hlsl_column_major,
+              EltTy, EltTy);
+
+          QualType NewTy = NewEltTy;
+          if (Ty->isArrayType()) {
+            // Build new array type.
+            SmallVector<const ArrayType *, 2> arrayTys;
+            while (EltTy->isArrayType()) {
+              const ArrayType *AT = EltTy->getAsArrayTypeUnsafe();
+              arrayTys.emplace_back(AT);
+            }
+            for (auto rit = arrayTys.rbegin(); rit != arrayTys.rend(); rit++) {
+              // Create array type with NewTy.
+              const ArrayType *AT = *rit;
+              if (const ConstantArrayType *CAT =
+                      dyn_cast<ConstantArrayType>(AT)) {
+                NewTy = Context.getConstantArrayType(
+                    NewTy, CAT->getSize(), CAT->getSizeModifier(),
+                    CAT->getIndexTypeCVRQualifiers());
+              } else if (const IncompleteArrayType *IAT =
+                             dyn_cast<IncompleteArrayType>(AT)) {
+                NewTy = Context.getIncompleteArrayType(NewTy, IAT->getSizeModifier(),
+                    IAT->getIndexTypeCVRQualifiers());
+              } else {
+                DXASSERT(false, "");
+              }
+            }
+          }
+          // Update Type.
+          VD->setType(NewTy);
+        }
+      }
+    }
+  }
 }
 
 /// Checks whether a usage attribute is compatible with those seen so far and

+ 35 - 0
tools/clang/test/CodeGenHLSL/quick-test/pack_matrix.hlsl

@@ -0,0 +1,35 @@
+// RUN: %dxc -E main -T ps_6_0 -ast-dump %s  | FileCheck %s
+
+// CHECK:row_major
+#pragma pack_matrix(row_major)
+
+struct Foo
+{
+  float2x2 a;
+};
+
+// CHECK:column_major
+#pragma pack_matrix(column_major)
+
+struct Bar {
+  float2x2 a;
+};
+
+Foo f;
+Bar b;
+
+// CHECK:row_major
+#pragma pack_matrix(row_major)
+
+float2x2 c;
+
+// CHECK:column_major
+#pragma pack_matrix(column_major)
+float2x2 d;
+
+// CHECK: main 'float4 ()'
+float4 main() : SV_Target
+{
+  float2x2 e = f.a + b.a + c + d;
+  return e;
+}