|
|
@@ -0,0 +1,377 @@
|
|
|
+// Copyright (c) 2024 Google LLC
|
|
|
+//
|
|
|
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
|
+// See https://llvm.org/LICENSE.txt for license information.
|
|
|
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
+
|
|
|
+#include "vk/opcode_selector.h"
|
|
|
+
|
|
|
+template <typename ResultType, typename ComponentType>
|
|
|
+[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143)]] ResultType
|
|
|
+__builtin_spv_MatrixTimesScalar(ResultType a, ComponentType b);
|
|
|
+
|
|
|
+template <typename ComponentType, vk::Scope scope, uint rows, uint columns,
|
|
|
+ vk::CooperativeMatrixUse use>
|
|
|
+[[vk::ext_instruction(/* OpCompositeExtract */ 81)]] ComponentType
|
|
|
+__builtin_spv_ExtractFromCooperativeMatrix(
|
|
|
+ typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
|
|
|
+ use>::SpirvMatrixType matrix,
|
|
|
+ uint32_t index);
|
|
|
+
|
|
|
+template <typename CoopMatrixType, typename ComponentType>
|
|
|
+[[vk::ext_instruction(/* OpCompositeConstruct */ 80)]] CoopMatrixType
|
|
|
+__builtin_spv_ConstructCooperativeMatrix(ComponentType value);
|
|
|
+
|
|
|
+template <class ResultPointerType, class BaseType>
|
|
|
+[[vk::ext_instruction(/* OpAccessChain */ 65)]] ResultPointerType
|
|
|
+__builtin_spv_AccessChain([[vk::ext_reference]] BaseType base, uint32_t index);
|
|
|
+
|
|
|
+template <class ObjectType, class PointerType>
|
|
|
+[[vk::ext_instruction(/* OpLoad */ 61)]] ObjectType
|
|
|
+__builtin_spv_LoadPointer(PointerType base);
|
|
|
+
|
|
|
+template <class PointerType, class ObjectType>
|
|
|
+[[vk::ext_instruction(/* OpLoad */ 62)]] void
|
|
|
+__builtin_spv_StorePointer(PointerType base, ObjectType object);
|
|
|
+
|
|
|
+template <typename ComponentType, vk::Scope scope, uint rows, uint columns,
|
|
|
+ vk::CooperativeMatrixUse use>
|
|
|
+[[vk::ext_instruction(/* OpCompositeInsert */ 82)]]
|
|
|
+typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
|
|
|
+ use>::SpirvMatrixType
|
|
|
+__builtin_spv_InsertIntoCooperativeMatrix(
|
|
|
+ ComponentType value,
|
|
|
+ typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
|
|
|
+ use>::SpirvMatrixType matrix,
|
|
|
+ uint32_t index);
|
|
|
+
|
|
|
+// Define the load and store instructions
|
|
|
+template <typename ResultType, typename PointerType>
|
|
|
+[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
|
|
|
+__builtin_spv_CooperativeMatrixLoadKHR(
|
|
|
+ [[vk::ext_reference]] PointerType pointer,
|
|
|
+ vk::CooperativeMatrixLayout memory_layout, uint stride,
|
|
|
+ [[vk::ext_literal]] uint32_t memory_operand);
|
|
|
+
|
|
|
+template <typename ResultType, typename PointerType>
|
|
|
+[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
|
|
|
+__builtin_spv_CooperativeMatrixLoadKHR(
|
|
|
+ [[vk::ext_reference]] PointerType pointer,
|
|
|
+ vk::CooperativeMatrixLayout memory_layout, uint stride,
|
|
|
+ [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
|
|
|
+
|
|
|
+template <typename ResultType, typename PointerType>
|
|
|
+[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
|
|
|
+__builtin_spv_CooperativeMatrixWorkgroupLoadKHR(
|
|
|
+ vk::WorkgroupSpirvPointer<PointerType> pointer,
|
|
|
+ vk::CooperativeMatrixLayout memory_layout, uint stride,
|
|
|
+ [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
|
|
|
+
|
|
|
+template <typename ObjectType, typename PointerType>
|
|
|
+[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
|
|
|
+__builtin_spv_CooperativeMatrixStoreKHR(
|
|
|
+ [[vk::ext_reference]] PointerType pointer, ObjectType object,
|
|
|
+ vk::CooperativeMatrixLayout memory_layout, uint stride,
|
|
|
+ [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
|
|
|
+
|
|
|
+template <typename ObjectType, typename PointerType>
|
|
|
+[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
|
|
|
+__builtin_spv_CooperativeMatrixStoreKHR(
|
|
|
+ [[vk::ext_reference]] PointerType pointer, ObjectType object,
|
|
|
+ vk::CooperativeMatrixLayout memory_layout, uint stride,
|
|
|
+ [[vk::ext_literal]] uint32_t memory_operand);
|
|
|
+
|
|
|
+template <typename ObjectType, typename PointerType>
|
|
|
+[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
|
|
|
+__builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
|
|
|
+ vk::WorkgroupSpirvPointer<PointerType> pointer, ObjectType object,
|
|
|
+ vk::CooperativeMatrixLayout memory_layout, uint stride,
|
|
|
+ [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
|
|
|
+
|
|
|
+// We cannot define `OpCooperativeMatrixLengthKHR` using ext_instruction because
|
|
|
+// one of the operands is a type id. This builtin will have specific code in the
|
|
|
+// compiler to expand it.
|
|
|
+template <class MatrixType> uint __builtin_spv_CooperativeMatrixLengthKHR();
|
|
|
+
|
|
|
+// Arithmetic Instructions
|
|
|
+template <typename ResultType, typename MatrixTypeA, typename MatrixTypeB,
|
|
|
+ typename MatrixTypeC>
|
|
|
+[[vk::ext_instruction(/* OpCooperativeMatrixMulAddKHR */ 4459)]] ResultType
|
|
|
+__builtin_spv_CooperativeMatrixMulAddKHR(MatrixTypeA a, MatrixTypeB b,
|
|
|
+ MatrixTypeC c,
|
|
|
+ [[vk::ext_literal]] int operands);
|
|
|
+namespace vk {
|
|
|
+namespace khr {
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <class NewComponentType>
|
|
|
+CooperativeMatrix<NewComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::cast() {
|
|
|
+ using ResultType =
|
|
|
+ CooperativeMatrix<NewComponentType, scope, rows, columns, use>;
|
|
|
+ ResultType result;
|
|
|
+ result._matrix = util::ConversionSelector<ComponentType, NewComponentType>::
|
|
|
+ template Convert<typename ResultType::SpirvMatrixType>(_matrix);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::negate() {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix = util::ArithmeticSelector<ComponentType>::Negate(_matrix);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator+(
|
|
|
+ CooperativeMatrix other) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix =
|
|
|
+ util::ArithmeticSelector<ComponentType>::Add(_matrix, other._matrix);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator-(
|
|
|
+ CooperativeMatrix other) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix =
|
|
|
+ util::ArithmeticSelector<ComponentType>::Sub(_matrix, other._matrix);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
|
|
|
+ CooperativeMatrix other) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix =
|
|
|
+ util::ArithmeticSelector<ComponentType>::Mul(_matrix, other._matrix);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator/(
|
|
|
+ CooperativeMatrix other) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix =
|
|
|
+ util::ArithmeticSelector<ComponentType>::Div(_matrix, other._matrix);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
|
|
|
+ ComponentType scalar) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
|
|
|
+ class Type>
|
|
|
+void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Store(
|
|
|
+ WorkgroupSpirvPointer<Type> data, uint32_t stride) {
|
|
|
+ __builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
|
|
|
+ data, _matrix, layout, stride,
|
|
|
+ memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
|
|
|
+ MemoryAccessMakePointerAvailableMask,
|
|
|
+ ScopeWorkgroup);
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
|
|
|
+ class Type>
|
|
|
+void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Store(
|
|
|
+ RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride) {
|
|
|
+ __builtin_spv_CooperativeMatrixStoreKHR(data[index], _matrix, layout, stride,
|
|
|
+ memoryAccessOperands);
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
|
|
|
+ class Type>
|
|
|
+void CooperativeMatrix<ComponentType, scope, rows, columns, use>::CoherentStore(
|
|
|
+ globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
|
|
|
+ uint32_t stride) {
|
|
|
+ __builtin_spv_CooperativeMatrixStoreKHR(
|
|
|
+ data[index], _matrix, layout, stride,
|
|
|
+ memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
|
|
|
+ MemoryAccessMakePointerAvailableMask,
|
|
|
+ ScopeQueueFamily);
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
|
|
|
+ class Type>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
|
|
|
+ vk::WorkgroupSpirvPointer<Type> buffer, uint32_t stride) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix =
|
|
|
+ __builtin_spv_CooperativeMatrixWorkgroupLoadKHR<SpirvMatrixType>(
|
|
|
+ buffer, layout, stride,
|
|
|
+ memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
|
|
|
+ MemoryAccessMakePointerVisibleMask,
|
|
|
+ ScopeWorkgroup);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
|
|
|
+ class Type>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
|
|
|
+ RWStructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
|
|
|
+ buffer[index], layout, stride, memoryAccessOperands);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
|
|
|
+ class Type>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::CoherentLoad(
|
|
|
+ RWStructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
|
|
|
+ buffer[index], layout, stride,
|
|
|
+ memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
|
|
|
+ MemoryAccessMakePointerVisibleMask,
|
|
|
+ ScopeQueueFamily);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
|
|
|
+ class Type>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
|
|
|
+ StructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
|
|
|
+ buffer[index], layout, stride, MemoryAccessMaskNone);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>
|
|
|
+CooperativeMatrix<ComponentType, scope, rows, columns, use>::Splat(
|
|
|
+ ComponentType v) {
|
|
|
+ CooperativeMatrix result;
|
|
|
+ result._matrix = __builtin_spv_ConstructCooperativeMatrix<SpirvMatrixType>(v);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+uint CooperativeMatrix<ComponentType, scope, rows, columns, use>::GetLength() {
|
|
|
+ return __builtin_spv_CooperativeMatrixLengthKHR<SpirvMatrixType>();
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+ComponentType CooperativeMatrix<ComponentType, scope, rows, columns, use>::Get(
|
|
|
+ uint32_t index) {
|
|
|
+ // clang-format off
|
|
|
+ using ComponentPtr = vk::SpirvOpaqueType<
|
|
|
+ /* OpTypePointer */ 32,
|
|
|
+ /* function storage class */ vk::Literal<vk::integral_constant<uint, 7> >,
|
|
|
+ ComponentType>;
|
|
|
+ // clang-format on
|
|
|
+ ComponentPtr ptr = __builtin_spv_AccessChain<ComponentPtr>(_matrix, index);
|
|
|
+ return __builtin_spv_LoadPointer<ComponentType>(ptr);
|
|
|
+}
|
|
|
+
|
|
|
+template <class ComponentType, Scope scope, uint rows, uint columns,
|
|
|
+ CooperativeMatrixUse use>
|
|
|
+void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Set(
|
|
|
+ ComponentType value, uint32_t index) {
|
|
|
+ // clang-format off
|
|
|
+ using ComponentPtr = vk::SpirvOpaqueType<
|
|
|
+ /* OpTypePointer */ 32,
|
|
|
+ /* function storage class */ vk::Literal<vk::integral_constant<uint, 7> >,
|
|
|
+ ComponentType>;
|
|
|
+ // clang-format on
|
|
|
+ ComponentPtr ptr = __builtin_spv_AccessChain<ComponentPtr>(_matrix, index);
|
|
|
+ return __builtin_spv_StorePointer(ptr, value);
|
|
|
+}
|
|
|
+
|
|
|
+template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
|
|
|
+CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
|
|
|
+cooperativeMatrixMultiplyAdd(
|
|
|
+ CooperativeMatrixA<ComponentType, scope, rows, K> a,
|
|
|
+ CooperativeMatrixB<ComponentType, scope, K, columns> b,
|
|
|
+ CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c) {
|
|
|
+
|
|
|
+ const vk::CooperativeMatrixOperandsMask allSignedComponents =
|
|
|
+ vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
|
|
|
+ vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
|
|
|
+ vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
|
|
|
+ vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask;
|
|
|
+
|
|
|
+ const vk::CooperativeMatrixOperandsMask operands =
|
|
|
+ (vk::CooperativeMatrixOperandsMask)(
|
|
|
+ a.hasSignedIntegerComponentType
|
|
|
+ ? allSignedComponents
|
|
|
+ : vk::CooperativeMatrixOperandsMaskNone);
|
|
|
+
|
|
|
+ CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
|
|
|
+ result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
|
|
|
+ typename CooperativeMatrixAccumulator<ComponentType, scope, rows,
|
|
|
+ columns>::SpirvMatrixType>(
|
|
|
+ a._matrix, b._matrix, c._matrix, operands);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
|
|
|
+CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
|
|
|
+cooperativeMatrixSaturatingMultiplyAdd(
|
|
|
+ CooperativeMatrixA<ComponentType, scope, rows, K> a,
|
|
|
+ CooperativeMatrixB<ComponentType, scope, K, columns> b,
|
|
|
+ CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c) {
|
|
|
+
|
|
|
+ const vk::CooperativeMatrixOperandsMask allSignedComponents =
|
|
|
+ vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
|
|
|
+ vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
|
|
|
+ vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
|
|
|
+ vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask |
|
|
|
+ vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask;
|
|
|
+
|
|
|
+ const vk::CooperativeMatrixOperandsMask operands =
|
|
|
+ (vk::CooperativeMatrixOperandsMask)(
|
|
|
+ a.hasSignedIntegerComponentType
|
|
|
+ ? allSignedComponents
|
|
|
+ : vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask);
|
|
|
+ CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
|
|
|
+ result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
|
|
|
+ typename CooperativeMatrixAccumulator<ComponentType, scope, rows,
|
|
|
+ columns>::SpirvMatrixType>(
|
|
|
+ a._matrix, b._matrix, c._matrix, operands);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+} // namespace khr
|
|
|
+} // namespace vk
|