Переглянути джерело

Merge pull request #98212 from stuartcarnie/sgc/metal_improvements

Metal: Performance improvements and bug fixes
Thaddeus Crews 10 місяців тому
батько
коміт
4630cbc487

+ 38 - 15
drivers/metal/metal_objects.h

@@ -96,6 +96,22 @@ _FORCE_INLINE_ ShaderStageUsage &operator|=(ShaderStageUsage &p_a, int p_b) {
 	return p_a;
 }
 
+/// Bit flags describing how a resource is used by each shader stage.
+///
+/// Every enumerator is an `MTLResourceUsage` read or write flag shifted left by
+/// twice the index of its shader stage, giving each stage its own slot in the
+/// mask (assumes the read and write flags fit in two bits; TODO confirm against
+/// the Metal headers).
+enum StageResourceUsage : uint32_t {
+	VertexRead = (MTLResourceUsageRead << (RDD::SHADER_STAGE_VERTEX * 2)),
+	VertexWrite = (MTLResourceUsageWrite << (RDD::SHADER_STAGE_VERTEX * 2)),
+	FragmentRead = (MTLResourceUsageRead << (RDD::SHADER_STAGE_FRAGMENT * 2)),
+	FragmentWrite = (MTLResourceUsageWrite << (RDD::SHADER_STAGE_FRAGMENT * 2)),
+	TesselationControlRead = (MTLResourceUsageRead << (RDD::SHADER_STAGE_TESSELATION_CONTROL * 2)),
+	TesselationControlWrite = (MTLResourceUsageWrite << (RDD::SHADER_STAGE_TESSELATION_CONTROL * 2)),
+	TesselationEvaluationRead = (MTLResourceUsageRead << (RDD::SHADER_STAGE_TESSELATION_EVALUATION * 2)),
+	TesselationEvaluationWrite = (MTLResourceUsageWrite << (RDD::SHADER_STAGE_TESSELATION_EVALUATION * 2)),
+	ComputeRead = (MTLResourceUsageRead << (RDD::SHADER_STAGE_COMPUTE * 2)),
+	ComputeWrite = (MTLResourceUsageWrite << (RDD::SHADER_STAGE_COMPUTE * 2)),
+};
+
+typedef LocalVector<__unsafe_unretained id<MTLResource>> ResourceVector; // List of Metal resources referenced without retaining them.
+typedef HashMap<StageResourceUsage, ResourceVector> ResourceUsageMap; // Groups resources under the per-stage usage mask they share.
+
 enum class MDCommandBufferStateType {
 	None,
 	Render,
@@ -230,6 +246,7 @@ public:
 		uint32_t index_offset = 0;
 		LocalVector<id<MTLBuffer> __unsafe_unretained> vertex_buffers;
 		LocalVector<NSUInteger> vertex_offsets;
+		ResourceUsageMap resource_usage;
 		// clang-format off
 		enum DirtyFlag: uint8_t {
 			DIRTY_NONE     = 0b0000'0000,
@@ -271,8 +288,14 @@ public:
 			blend_constants.reset();
 			vertex_buffers.clear();
 			vertex_offsets.clear();
+			// Keep the keys, as they are likely to be used again.
+			for (KeyValue<StageResourceUsage, LocalVector<__unsafe_unretained id<MTLResource>>> &kv : resource_usage) {
+				kv.value.clear();
+			}
 		}
 
+		void end_encoding();
+
 		_FORCE_INLINE_ void mark_viewport_dirty() {
 			if (viewports.is_empty()) {
 				return;
@@ -356,13 +379,20 @@ public:
 	} render;
 
 	// State specific for a compute pass.
-	struct {
+	struct ComputeState {
 		MDComputePipeline *pipeline = nullptr;
 		id<MTLComputeCommandEncoder> encoder = nil;
+		ResourceUsageMap resource_usage;
 		_FORCE_INLINE_ void reset() {
 			pipeline = nil;
 			encoder = nil;
+			// Keep the keys, as they are likely to be used again.
+			for (KeyValue<StageResourceUsage, LocalVector<__unsafe_unretained id<MTLResource>>> &kv : resource_usage) {
+				kv.value.clear();
+			}
 		}
+
+		void end_encoding();
 	} compute;
 
 	// State specific to a blit pass.
@@ -632,19 +662,6 @@ public:
 	MDRenderShader(CharString p_name, Vector<UniformSet> p_sets, MDLibrary *p_vert, MDLibrary *p_frag);
 };
 
-enum StageResourceUsage : uint32_t {
-	VertexRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_VERTEX * 2),
-	VertexWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_VERTEX * 2),
-	FragmentRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_FRAGMENT * 2),
-	FragmentWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_FRAGMENT * 2),
-	TesselationControlRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_TESSELATION_CONTROL * 2),
-	TesselationControlWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_TESSELATION_CONTROL * 2),
-	TesselationEvaluationRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_TESSELATION_EVALUATION * 2),
-	TesselationEvaluationWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_TESSELATION_EVALUATION * 2),
-	ComputeRead = (MTLResourceUsageRead << RDD::SHADER_STAGE_COMPUTE * 2),
-	ComputeWrite = (MTLResourceUsageWrite << RDD::SHADER_STAGE_COMPUTE * 2),
-};
-
 _FORCE_INLINE_ StageResourceUsage &operator|=(StageResourceUsage &p_a, uint32_t p_b) {
 	p_a = StageResourceUsage(uint32_t(p_a) | p_b);
 	return p_a;
@@ -667,7 +684,13 @@ struct HashMapComparatorDefault<RDD::ShaderID> {
 
 struct BoundUniformSet {
 	id<MTLBuffer> buffer;
-	HashMap<id<MTLResource>, StageResourceUsage> bound_resources;
+	ResourceUsageMap usage_to_resources;
+
+	/// Performs a 2-way merge of each key's `ResourceVector` from this set into the
+	/// destination set.
+	///
+	/// Assumes the vectors of resources are sorted.
+	void merge_into(ResourceUsageMap &p_dst) const;
 };
 
 class API_AVAILABLE(macos(11.0), ios(14.0)) MDUniformSet {

+ 103 - 27
drivers/metal/metal_objects.mm

@@ -58,7 +58,7 @@
 
 void MDCommandBuffer::begin() {
 	DEV_ASSERT(commandBuffer == nil);
-	commandBuffer = queue.commandBuffer;
+	commandBuffer = queue.commandBufferWithUnretainedReferences;
 }
 
 void MDCommandBuffer::end() {
@@ -390,6 +390,38 @@ void MDCommandBuffer::render_set_blend_constants(const Color &p_constants) {
 	}
 }
 
+// Merges every usage group of this set into `p_dst`, skipping resources that
+// `p_dst` already holds under the same usage key. Both sides are expected to
+// be sorted by pointer value; the merge walks them in lock-step.
+void BoundUniformSet::merge_into(ResourceUsageMap &p_dst) const {
+	for (KeyValue<StageResourceUsage, ResourceVector> const &entry : usage_to_resources) {
+		const ResourceVector &src = entry.value;
+
+		// Look up the destination group for this usage, creating it if missing.
+		ResourceVector *dst = p_dst.getptr(entry.key);
+		if (dst == nullptr) {
+			dst = &p_dst.insert(entry.key, ResourceVector())->value;
+		}
+
+		// Make room for the worst case, where no source resource is present yet.
+		dst->reserve(dst->size() + src.size());
+
+		// NOTE(review): the raw pointers are taken after the reserve; presumably
+		// that keeps insert/push_back below from reallocating. Confirm against
+		// LocalVector.
+		__unsafe_unretained id<MTLResource> *dst_ptr = dst->ptr();
+		const __unsafe_unretained id<MTLResource> *src_ptr = src.ptr();
+
+		uint32_t dst_idx = 0;
+		uint32_t src_idx = 0;
+		for (; dst_idx < dst->size() && src_idx < src.size(); dst_idx++) {
+			if (dst_ptr[dst_idx] < src_ptr[src_idx]) {
+				// Destination element sorts first; keep it and move on.
+				continue;
+			}
+			if (dst_ptr[dst_idx] > src_ptr[src_idx]) {
+				// Source element is missing from the destination; slot it in here.
+				dst->insert(dst_idx, src_ptr[src_idx]);
+			}
+			// Either inserted or already present: the source element is consumed.
+			src_idx++;
+		}
+
+		// Whatever is left in the source sorts after everything in the destination.
+		for (; src_idx < src.size(); src_idx++) {
+			dst->push_back(src_ptr[src_idx]);
+		}
+	}
+}
+
 void MDCommandBuffer::_render_bind_uniform_sets() {
 	DEV_ASSERT(type == MDCommandBufferStateType::Render);
 	if (!render.dirty.has_flag(RenderState::DIRTY_UNIFORMS)) {
@@ -408,7 +440,7 @@ void MDCommandBuffer::_render_bind_uniform_sets() {
 		// Find the index of the next set bit.
 		int index = __builtin_ctzll(set_uniforms);
 		// Clear the set bit.
-		set_uniforms &= ~(1ULL << index);
+		set_uniforms &= (set_uniforms - 1);
 		MDUniformSet *set = render.uniform_sets[index];
 		if (set == nullptr || set->index >= (uint32_t)shader->sets.size()) {
 			continue;
@@ -416,17 +448,7 @@ void MDCommandBuffer::_render_bind_uniform_sets() {
 		UniformSet const &set_info = shader->sets[set->index];
 
 		BoundUniformSet &bus = set->boundUniformSetForShader(shader, device);
-
-		for (KeyValue<id<MTLResource>, StageResourceUsage> const &keyval : bus.bound_resources) {
-			MTLResourceUsage usage = resource_usage_for_stage(keyval.value, RDD::ShaderStage::SHADER_STAGE_VERTEX);
-			if (usage != 0) {
-				[enc useResource:keyval.key usage:usage stages:MTLRenderStageVertex];
-			}
-			usage = resource_usage_for_stage(keyval.value, RDD::ShaderStage::SHADER_STAGE_FRAGMENT);
-			if (usage != 0) {
-				[enc useResource:keyval.key usage:usage stages:MTLRenderStageFragment];
-			}
-		}
+		bus.merge_into(render.resource_usage);
 
 		// Set the buffer for the vertex stage.
 		{
@@ -545,8 +567,7 @@ void MDCommandBuffer::_end_render_pass() {
 		// see: https://github.com/KhronosGroup/MoltenVK/blob/d20d13fe2735adb845636a81522df1b9d89c0fba/MoltenVK/MoltenVK/GPUObjects/MVKRenderPass.mm#L407
 	}
 
-	[render.encoder endEncoding];
-	render.encoder = nil;
+	render.end_encoding();
 }
 
 void MDCommandBuffer::_render_clear_render_area() {
@@ -792,10 +813,59 @@ void MDCommandBuffer::render_draw_indirect_count(RDD::BufferID p_indirect_buffer
 	ERR_FAIL_MSG("not implemented");
 }
 
+// Declares the resource usage gathered for this render pass on the encoder and
+// then ends the encoder. Does nothing when there is no active encoder.
+void MDCommandBuffer::RenderState::end_encoding() {
+	if (encoder == nil) {
+		return;
+	}
+
+	// Bind all resources.
+	for (KeyValue<StageResourceUsage, ResourceVector> const &keyval : resource_usage) {
+		if (keyval.value.is_empty()) {
+			continue;
+		}
+
+		MTLResourceUsage vert_usage = resource_usage_for_stage(keyval.key, RDD::ShaderStage::SHADER_STAGE_VERTEX);
+		MTLResourceUsage frag_usage = resource_usage_for_stage(keyval.key, RDD::ShaderStage::SHADER_STAGE_FRAGMENT);
+		if (vert_usage == 0 && frag_usage == 0) {
+			// Neither render stage touches this group; skip it instead of issuing
+			// a call with an empty usage mask, as the compute path also does.
+			continue;
+		}
+
+		if (vert_usage == frag_usage) {
+			// Same usage in both stages: a single call covers vertex and fragment.
+			[encoder useResources:keyval.value.ptr() count:keyval.value.size() usage:vert_usage stages:MTLRenderStageVertex | MTLRenderStageFragment];
+		} else {
+			if (vert_usage != 0) {
+				[encoder useResources:keyval.value.ptr() count:keyval.value.size() usage:vert_usage stages:MTLRenderStageVertex];
+			}
+			if (frag_usage != 0) {
+				[encoder useResources:keyval.value.ptr() count:keyval.value.size() usage:frag_usage stages:MTLRenderStageFragment];
+			}
+		}
+	}
+
+	[encoder endEncoding];
+	encoder = nil;
+}
+
+// Declares the resource usage gathered for this compute pass on the encoder and
+// then ends the encoder. Does nothing when there is no active encoder.
+void MDCommandBuffer::ComputeState::end_encoding() {
+	if (encoder == nil) {
+		return;
+	}
+
+	// Declare each non-empty group of resources with its compute-stage usage.
+	for (KeyValue<StageResourceUsage, ResourceVector> const &entry : resource_usage) {
+		const ResourceVector &group = entry.value;
+		if (!group.is_empty()) {
+			const MTLResourceUsage compute_usage = resource_usage_for_stage(entry.key, RDD::ShaderStage::SHADER_STAGE_COMPUTE);
+			if (compute_usage != 0) {
+				[encoder useResources:group.ptr() count:group.size() usage:compute_usage];
+			}
+		}
+	}
+
+	[encoder endEncoding];
+	encoder = nil;
+}
+
 void MDCommandBuffer::render_end_pass() {
 	DEV_ASSERT(type == MDCommandBufferStateType::Render);
 
-	[render.encoder endEncoding];
+	render.end_encoding();
 	render.reset();
 	type = MDCommandBufferStateType::None;
 }
@@ -813,13 +883,7 @@ void MDCommandBuffer::compute_bind_uniform_set(RDD::UniformSetID p_uniform_set,
 
 	MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id);
 	BoundUniformSet &bus = set->boundUniformSetForShader(shader, device);
-
-	for (KeyValue<id<MTLResource>, StageResourceUsage> &keyval : bus.bound_resources) {
-		MTLResourceUsage usage = resource_usage_for_stage(keyval.value, RDD::ShaderStage::SHADER_STAGE_COMPUTE);
-		if (usage != 0) {
-			[enc useResource:keyval.key usage:usage];
-		}
-	}
+	bus.merge_into(compute.resource_usage);
 
 	uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_COMPUTE);
 	if (offset) {
@@ -848,7 +912,7 @@ void MDCommandBuffer::compute_dispatch_indirect(RDD::BufferID p_indirect_buffer,
 void MDCommandBuffer::_end_compute_dispatch() {
 	DEV_ASSERT(type == MDCommandBufferStateType::Compute);
 
-	[compute.encoder endEncoding];
+	compute.end_encoding();
 	compute.reset();
 	type = MDCommandBufferStateType::None;
 }
@@ -1052,7 +1116,20 @@ BoundUniformSet &MDUniformSet::boundUniformSetForShader(MDShader *p_shader, id<M
 		}
 	}
 
-	BoundUniformSet bs = { .buffer = enc_buffer, .bound_resources = bound_resources };
+	SearchArray<__unsafe_unretained id<MTLResource>> search;
+	ResourceUsageMap usage_to_resources;
+	for (KeyValue<id<MTLResource>, StageResourceUsage> const &keyval : bound_resources) {
+		ResourceVector *resources = usage_to_resources.getptr(keyval.value);
+		if (resources == nullptr) {
+			resources = &usage_to_resources.insert(keyval.value, ResourceVector())->value;
+		}
+		int64_t pos = search.bisect(resources->ptr(), resources->size(), keyval.key, true);
+		if (pos == resources->size() || (*resources)[pos] != keyval.key) {
+			resources->insert(pos, keyval.key);
+		}
+	}
+
+	BoundUniformSet bs = { .buffer = enc_buffer, .usage_to_resources = usage_to_resources };
 	bound_uniforms.insert(p_shader, bs);
 	return bound_uniforms.get(p_shader);
 }
@@ -1211,8 +1288,7 @@ vertex VaryingsPos vertClear(AttributesPos attributes [[stage_in]], constant Cle
     varyings.layer = uint(attributes.a_position.w);
     return varyings;
 }
-)",
-				ClearAttKey::DEPTH_INDEX];
+)", ClearAttKey::DEPTH_INDEX];
 
 		return new_func(msl, @"vertClear", nil);
 	}

+ 5 - 1
drivers/metal/rendering_device_driver_metal.mm

@@ -2060,6 +2060,10 @@ Vector<uint8_t> RenderingDeviceDriverMetal::shader_compile_binary_from_spirv(Vec
 
 					case BT::Sampler: {
 						primary.dataType = MTLDataTypeSampler;
+						primary.arrayLength = 1;
+						for (uint32_t const &a : a_type.array) {
+							primary.arrayLength *= a;
+						}
 					} break;
 
 					default: {
@@ -2067,7 +2071,7 @@ Vector<uint8_t> RenderingDeviceDriverMetal::shader_compile_binary_from_spirv(Vec
 					} break;
 				}
 
-				// Find array length.
+				// Find array length of image.
 				if (basetype == BT::Image || basetype == BT::SampledImage) {
 					primary.arrayLength = 1;
 					for (uint32_t const &a : a_type.array) {