|
|
@@ -197,17 +197,49 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
+bool ContainsCooperativeMatrix(ValidationState_t& _,
|
|
|
+ const Instruction* storage) {
|
|
|
+ const size_t elem_type_index = 1;
|
|
|
+ uint32_t elem_type_id;
|
|
|
+ Instruction* elem_type;
|
|
|
+
|
|
|
+ switch (storage->opcode()) {
|
|
|
+ case SpvOpTypeCooperativeMatrixNV:
|
|
|
+ return true;
|
|
|
+ case SpvOpTypeArray:
|
|
|
+ case SpvOpTypeRuntimeArray:
|
|
|
+ elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
|
|
|
+ elem_type = _.FindDef(elem_type_id);
|
|
|
+ return ContainsCooperativeMatrix(_, elem_type);
|
|
|
+ case SpvOpTypeStruct:
|
|
|
+ for (size_t member_type_index = 1;
|
|
|
+ member_type_index < storage->operands().size();
|
|
|
+ ++member_type_index) {
|
|
|
+ auto member_type_id =
|
|
|
+ storage->GetOperandAs<uint32_t>(member_type_index);
|
|
|
+ auto member_type = _.FindDef(member_type_id);
|
|
|
+ if (ContainsCooperativeMatrix(_, member_type)) return true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+}
|
|
|
+
|
|
|
std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
|
|
|
ValidationState_t& _, const Instruction* inst) {
|
|
|
SpvStorageClass dst_sc = SpvStorageClassMax;
|
|
|
SpvStorageClass src_sc = SpvStorageClassMax;
|
|
|
switch (inst->opcode()) {
|
|
|
+ case SpvOpCooperativeMatrixLoadNV:
|
|
|
case SpvOpLoad: {
|
|
|
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
|
|
|
auto load_pointer_type = _.FindDef(load_pointer->type_id());
|
|
|
dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
|
|
|
break;
|
|
|
}
|
|
|
+ case SpvOpCooperativeMatrixStoreNV:
|
|
|
case SpvOpStore: {
|
|
|
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
|
|
|
auto store_pointer_type = _.FindDef(store_pointer->type_id());
|
|
|
@@ -232,7 +264,8 @@ std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
|
|
|
}
|
|
|
|
|
|
// This function is only called for OpLoad, OpStore, OpCopyMemory and
|
|
|
-// OpCopyMemorySized.
|
|
|
+// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
|
|
|
+// OpCooperativeMatrixStoreNV.
|
|
|
uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
|
|
|
uint32_t offset = 1;
|
|
|
if (mask & SpvMemoryAccessAlignedMask) ++offset;
|
|
|
@@ -245,6 +278,10 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
|
|
|
case SpvOpStore:
|
|
|
case SpvOpCopyMemory:
|
|
|
return inst->GetOperandAs<uint32_t>(2 + offset);
|
|
|
+ case SpvOpCooperativeMatrixLoadNV:
|
|
|
+ return inst->GetOperandAs<uint32_t>(5 + offset);
|
|
|
+ case SpvOpCooperativeMatrixStoreNV:
|
|
|
+ return inst->GetOperandAs<uint32_t>(4 + offset);
|
|
|
default:
|
|
|
assert(false && "unexpected opcode");
|
|
|
break;
|
|
|
@@ -253,8 +290,9 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
|
|
|
return scope_id;
|
|
|
}
|
|
|
|
|
|
-// This function is only called for OpLoad, OpStore, OpCopyMemory and
|
|
|
-// OpCopyMemorySized.
|
|
|
+// This function is only called for OpLoad, OpStore, OpCopyMemory,
|
|
|
+// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
|
|
|
+// OpCooperativeMatrixStoreNV.
|
|
|
uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
|
|
|
uint32_t offset = 1;
|
|
|
if (mask & SpvMemoryAccessAlignedMask) ++offset;
|
|
|
@@ -268,6 +306,10 @@ uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
|
|
|
case SpvOpStore:
|
|
|
case SpvOpCopyMemory:
|
|
|
return inst->GetOperandAs<uint32_t>(2 + offset);
|
|
|
+ case SpvOpCooperativeMatrixLoadNV:
|
|
|
+ return inst->GetOperandAs<uint32_t>(5 + offset);
|
|
|
+ case SpvOpCooperativeMatrixStoreNV:
|
|
|
+ return inst->GetOperandAs<uint32_t>(4 + offset);
|
|
|
default:
|
|
|
assert(false && "unexpected opcode");
|
|
|
break;
|
|
|
@@ -302,7 +344,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
|
|
|
|
|
|
uint32_t mask = inst->GetOperandAs<uint32_t>(index);
|
|
|
if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
|
|
|
- if (inst->opcode() == SpvOpLoad) {
|
|
|
+ if (inst->opcode() == SpvOpLoad ||
|
|
|
+ inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
|
|
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
|
|
|
}
|
|
|
@@ -320,7 +363,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
|
|
|
}
|
|
|
|
|
|
if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
|
|
|
- if (inst->opcode() == SpvOpStore) {
|
|
|
+ if (inst->opcode() == SpvOpStore ||
|
|
|
+ inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
|
|
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
<< "MakePointerVisibleKHR cannot be used with OpStore.";
|
|
|
}
|
|
|
@@ -672,6 +716,17 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ // Cooperative matrix types can only be allocated in Function or Private
|
|
|
+ if ((storage_class != SpvStorageClassFunction &&
|
|
|
+ storage_class != SpvStorageClassPrivate) &&
|
|
|
+ ContainsCooperativeMatrix(_, pointee)) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << "Cooperative matrix types (or types containing them) can only be "
|
|
|
+ "allocated "
|
|
|
+ << "in Function or Private storage classes or as function "
|
|
|
+ "parameters";
|
|
|
+ }
|
|
|
+
|
|
|
return SPV_SUCCESS;
|
|
|
}
|
|
|
|
|
|
@@ -1003,10 +1058,11 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
|
|
|
switch (type_pointee->opcode()) {
|
|
|
case SpvOpTypeMatrix:
|
|
|
case SpvOpTypeVector:
|
|
|
+ case SpvOpTypeCooperativeMatrixNV:
|
|
|
case SpvOpTypeArray:
|
|
|
case SpvOpTypeRuntimeArray: {
|
|
|
- // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray,
|
|
|
- // word 2 is the Element Type.
|
|
|
+ // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
|
|
|
+ // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
|
|
|
type_pointee = _.FindDef(type_pointee->word(2));
|
|
|
break;
|
|
|
}
|
|
|
@@ -1136,6 +1192,140 @@ spv_result_t ValidateArrayLength(ValidationState_t& state,
|
|
|
return SPV_SUCCESS;
|
|
|
}
|
|
|
|
|
|
+spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
|
|
|
+ const Instruction* inst) {
|
|
|
+ std::string instr_name =
|
|
|
+ "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
|
|
|
+
|
|
|
+ // Result type must be a 32-bit unsigned int.
|
|
|
+ auto result_type = state.FindDef(inst->type_id());
|
|
|
+ if (result_type->opcode() != SpvOpTypeInt ||
|
|
|
+ result_type->GetOperandAs<uint32_t>(1) != 32 ||
|
|
|
+ result_type->GetOperandAs<uint32_t>(2) != 0) {
|
|
|
+ return state.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << "The Result Type of " << instr_name << " <id> '"
|
|
|
+ << state.getIdName(inst->id())
|
|
|
+ << "' must be OpTypeInt with width 32 and signedness 0.";
|
|
|
+ }
|
|
|
+
|
|
|
+ auto type_id = inst->GetOperandAs<uint32_t>(2);
|
|
|
+ auto type = state.FindDef(type_id);
|
|
|
+ if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
|
|
|
+ return state.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << "The type in " << instr_name << " <id> '"
|
|
|
+ << state.getIdName(type_id)
|
|
|
+ << "' must be OpTypeCooperativeMatrixNV.";
|
|
|
+ }
|
|
|
+ return SPV_SUCCESS;
|
|
|
+}
|
|
|
+
|
|
|
+spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
|
|
|
+ const Instruction* inst) {
|
|
|
+ uint32_t type_id;
|
|
|
+ const char* opname;
|
|
|
+ if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
|
|
|
+ type_id = inst->type_id();
|
|
|
+ opname = "SpvOpCooperativeMatrixLoadNV";
|
|
|
+ } else {
|
|
|
+ // get Object operand's type
|
|
|
+ type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
|
|
|
+ opname = "SpvOpCooperativeMatrixStoreNV";
|
|
|
+ }
|
|
|
+
|
|
|
+ auto matrix_type = _.FindDef(type_id);
|
|
|
+
|
|
|
+ if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
|
|
|
+ if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
|
|
|
+ << _.getIdName(type_id) << "' is not a cooperative matrix type.";
|
|
|
+ } else {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << "SpvOpCooperativeMatrixStoreNV Object type <id> '"
|
|
|
+ << _.getIdName(type_id) << "' is not a cooperative matrix type.";
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const bool uses_variable_pointers =
|
|
|
+ _.features().variable_pointers ||
|
|
|
+ _.features().variable_pointers_storage_buffer;
|
|
|
+ const auto pointer_index =
|
|
|
+ (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
|
|
|
+ const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
|
|
|
+ const auto pointer = _.FindDef(pointer_id);
|
|
|
+ if (!pointer ||
|
|
|
+ ((_.addressing_model() == SpvAddressingModelLogical) &&
|
|
|
+ ((!uses_variable_pointers &&
|
|
|
+ !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
|
|
|
+ (uses_variable_pointers &&
|
|
|
+ !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << opname << " Pointer <id> '" << _.getIdName(pointer_id)
|
|
|
+ << "' is not a logical pointer.";
|
|
|
+ }
|
|
|
+
|
|
|
+ const auto pointer_type_id = pointer->type_id();
|
|
|
+ const auto pointer_type = _.FindDef(pointer_type_id);
|
|
|
+ if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << opname << " type for pointer <id> '" << _.getIdName(pointer_id)
|
|
|
+ << "' is not a pointer type.";
|
|
|
+ }
|
|
|
+
|
|
|
+ const auto storage_class_index = 1u;
|
|
|
+ const auto storage_class =
|
|
|
+ pointer_type->GetOperandAs<uint32_t>(storage_class_index);
|
|
|
+
|
|
|
+ if (storage_class != SpvStorageClassWorkgroup &&
|
|
|
+ storage_class != SpvStorageClassStorageBuffer &&
|
|
|
+ storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << opname << " storage class for pointer type <id> '"
|
|
|
+ << _.getIdName(pointer_type_id)
|
|
|
+ << "' is not Workgroup or StorageBuffer.";
|
|
|
+ }
|
|
|
+
|
|
|
+ const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
|
|
|
+ const auto pointee_type = _.FindDef(pointee_id);
|
|
|
+ if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
|
|
|
+ _.IsFloatScalarOrVectorType(pointee_id))) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << opname << " Pointer <id> '" << _.getIdName(pointer->id())
|
|
|
+ << "'s Type must be a scalar or vector type.";
|
|
|
+ }
|
|
|
+
|
|
|
+ const auto stride_index =
|
|
|
+ (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
|
|
|
+ const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
|
|
|
+ const auto stride = _.FindDef(stride_id);
|
|
|
+ if (!stride || !_.IsIntScalarType(stride->type_id())) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << "Stride operand <id> '" << _.getIdName(stride_id)
|
|
|
+ << "' must be a scalar integer type.";
|
|
|
+ }
|
|
|
+
|
|
|
+ const auto colmajor_index =
|
|
|
+ (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
|
|
|
+ const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
|
|
|
+ const auto colmajor = _.FindDef(colmajor_id);
|
|
|
+ if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
|
|
|
+ !(spvOpcodeIsConstant(colmajor->opcode()) ||
|
|
|
+ spvOpcodeIsSpecConstant(colmajor->opcode()))) {
|
|
|
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
|
+ << "Column Major operand <id> '" << _.getIdName(colmajor_id)
|
|
|
+ << "' must be a boolean constant instruction.";
|
|
|
+ }
|
|
|
+
|
|
|
+ const auto memory_access_index =
|
|
|
+ (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
|
|
|
+ if (inst->operands().size() > memory_access_index) {
|
|
|
+ if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
|
|
|
+ return error;
|
|
|
+ }
|
|
|
+
|
|
|
+ return SPV_SUCCESS;
|
|
|
+}
|
|
|
+
|
|
|
} // namespace
|
|
|
|
|
|
spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
|
|
|
@@ -1164,6 +1354,14 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
|
|
|
case SpvOpArrayLength:
|
|
|
if (auto error = ValidateArrayLength(_, inst)) return error;
|
|
|
break;
|
|
|
+ case SpvOpCooperativeMatrixLoadNV:
|
|
|
+ case SpvOpCooperativeMatrixStoreNV:
|
|
|
+ if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
|
|
|
+ return error;
|
|
|
+ break;
|
|
|
+ case SpvOpCooperativeMatrixLengthNV:
|
|
|
+ if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
|
|
|
+ break;
|
|
|
case SpvOpImageTexelPointer:
|
|
|
case SpvOpGenericPtrMemSemantics:
|
|
|
default:
|