Pārlūkot izejas kodu

Allow usage of `discard` inside custom shader functions

Yuri Rubinsky 1 gadu atpakaļ
vecāks
revīzija
ad7e7a51b2

+ 16 - 7
servers/rendering/shader_language.cpp

@@ -356,7 +356,7 @@ const ShaderLanguage::KeyWord ShaderLanguage::keyword_list[] = {
 	{ TK_CF_BREAK, "break", CF_BLOCK, {}, {} },
 	{ TK_CF_CONTINUE, "continue", CF_BLOCK, {}, {} },
 	{ TK_CF_RETURN, "return", CF_BLOCK, {}, {} },
-	{ TK_CF_DISCARD, "discard", CF_BLOCK, { "particles", "sky", "fog" }, { "fragment" } },
+	{ TK_CF_DISCARD, "discard", CF_BLOCK, { "particles", "sky", "fog" }, { "vertex" } },
 
 	// function specifier keywords
 
@@ -8581,6 +8581,11 @@ Error ShaderLanguage::_parse_block(BlockNode *p_block, const FunctionInfo &p_fun
 				block = block->parent_block;
 			}
 		} else if (tk.type == TK_CF_DISCARD) {
+			if (!is_discard_supported) {
+				_set_error(vformat(RTR("Use of '%s' is not supported for the '%s' shader type."), "discard", shader_type_identifier));
+				return ERR_PARSE_ERROR;
+			}
+
 			//check return type
 			BlockNode *b = p_block;
 			while (b && !b->parent_function) {
@@ -8592,7 +8597,7 @@ Error ShaderLanguage::_parse_block(BlockNode *p_block, const FunctionInfo &p_fun
 			}
 
 			if (!b->parent_function->can_discard) {
-				_set_error(vformat(RTR("Use of '%s' is not allowed here."), "discard"));
+				_set_error(vformat(RTR("'%s' cannot be used within the '%s' processor function."), "discard", b->parent_function->name));
 				return ERR_PARSE_ERROR;
 			}
 
@@ -8601,6 +8606,9 @@ Error ShaderLanguage::_parse_block(BlockNode *p_block, const FunctionInfo &p_fun
 
 			pos = _get_tkpos();
 			tk = _get_token();
+
+			calls_info[b->parent_function->name].uses_restricted_items.push_back(Pair<StringName, CallInfo::Item>("discard", CallInfo::Item(CallInfo::Item::ITEM_TYPE_BUILTIN, pos)));
+
 			if (tk.type != TK_SEMICOLON) {
 				_set_expected_after_error(";", "discard");
 				return ERR_PARSE_ERROR;
@@ -8838,7 +8846,9 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 	ShaderNode::Uniform::Scope uniform_scope = ShaderNode::Uniform::SCOPE_LOCAL;
 
 	stages = &p_functions;
-	is_supported_frag_only_funcs = shader_type_identifier == "canvas_item" || shader_type_identifier == "spatial" || shader_type_identifier == "sky";
+
+	is_discard_supported = shader_type_identifier == "canvas_item" || shader_type_identifier == "spatial";
+	is_supported_frag_only_funcs = is_discard_supported || shader_type_identifier == "sky";
 
 	const FunctionInfo &constants = p_functions.has("constants") ? p_functions["constants"] : FunctionInfo();
 
@@ -10332,6 +10342,8 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 
 				if (p_functions.has(name)) {
 					func_node->can_discard = p_functions[name].can_discard;
+				} else {
+					func_node->can_discard = is_discard_supported; // Allow use it for custom functions (in supported shader types).
 				}
 
 				if (!function_overload_count.has(name)) {
@@ -10922,10 +10934,7 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 				break; // Ignore hint keywords (parsed below).
 			}
 			if (keyword_list[i].flags & keyword_completion_context) {
-				if (keyword_list[i].excluded_shader_types.has(shader_type_identifier)) {
-					continue;
-				}
-				if (!keyword_list[i].functions.is_empty() && !keyword_list[i].functions.has(current_function)) {
+				if (keyword_list[i].excluded_shader_types.has(shader_type_identifier) || keyword_list[i].excluded_functions.has(current_function)) {
 					continue;
 				}
 				ScriptLanguage::CodeCompletionOption option(keyword_list[i].text, ScriptLanguage::CODE_COMPLETION_KIND_PLAIN_TEXT);

+ 2 - 1
servers/rendering/shader_language.h

@@ -934,7 +934,7 @@ private:
 		const char *text;
 		uint32_t flags;
 		const Vector<String> excluded_shader_types;
-		const Vector<String> functions;
+		const Vector<String> excluded_functions;
 	};
 
 	static const KeyWord keyword_list[];
@@ -1150,6 +1150,7 @@ private:
 
 	const HashMap<StringName, FunctionInfo> *stages = nullptr;
 	bool is_supported_frag_only_funcs = false;
+	bool is_discard_supported = false;
 
 	bool _get_completable_identifier(BlockNode *p_block, CompletionType p_type, StringName &identifier);
 	static const BuiltinFuncDef builtin_func_defs[];