Pārlūkot izejas kodu

Allow using stage functions inside custom shader functions

Yuri Rubinsky 1 gadu atpakaļ
vecāks
revīzija
74c000db17

+ 68 - 38
servers/rendering/shader_language.cpp

@@ -3565,28 +3565,33 @@ bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionI
 
 	int argcount = args.size();
 
-	if (p_function_info.stage_functions.has(name)) {
-		//stage based function
-		const StageFunctionInfo &sf = p_function_info.stage_functions[name];
-		if (argcount != sf.arguments.size()) {
-			_set_error(vformat(RTR("Invalid number of arguments when calling stage function '%s', which expects %d arguments."), String(name), sf.arguments.size()));
-			return false;
-		}
-		//validate arguments
-		for (int i = 0; i < argcount; i++) {
-			if (args[i] != sf.arguments[i].type) {
-				_set_error(vformat(RTR("Invalid argument type when calling stage function '%s', type expected is '%s'."), String(name), get_datatype_name(sf.arguments[i].type)));
-				return false;
-			}
-		}
+	if (stages) {
+		// Stage functions can be used in custom functions as well, that why need to check them all.
+		for (const KeyValue<StringName, FunctionInfo> &E : *stages) {
+			if (E.value.stage_functions.has(name)) {
+				// Stage-based function.
+				const StageFunctionInfo &sf = E.value.stage_functions[name];
+				if (argcount != sf.arguments.size()) {
+					_set_error(vformat(RTR("Invalid number of arguments when calling stage function '%s', which expects %d arguments."), String(name), sf.arguments.size()));
+					return false;
+				}
+				// Validate arguments.
+				for (int i = 0; i < argcount; i++) {
+					if (args[i] != sf.arguments[i].type) {
+						_set_error(vformat(RTR("Invalid argument type when calling stage function '%s', type expected is '%s'."), String(name), get_datatype_name(sf.arguments[i].type)));
+						return false;
+					}
+				}
 
-		if (r_ret_type) {
-			*r_ret_type = sf.return_type;
-		}
-		if (r_ret_type_str) {
-			*r_ret_type_str = "";
+				if (r_ret_type) {
+					*r_ret_type = sf.return_type;
+				}
+				if (r_ret_type_str) {
+					*r_ret_type_str = "";
+				}
+				return true;
+			}
 		}
-		return true;
 	}
 
 	bool failed_builtin = false;
@@ -5937,22 +5942,35 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 							calls_info[current_function].calls.push_back(&calls_info[name]);
 						}
 
-						int idx = 0;
 						bool is_builtin = false;
 
-						while (frag_only_func_defs[idx].name) {
-							if (frag_only_func_defs[idx].name == name) {
-								// If a built-in function not found for the current shader type, then it shouldn't be parsed further.
-								if (!is_supported_frag_only_funcs) {
-									_set_error(vformat(RTR("Built-in function '%s' is not supported for the '%s' shader type."), name, shader_type_identifier));
-									return nullptr;
+						if (is_supported_frag_only_funcs && stages) {
+							for (const KeyValue<StringName, FunctionInfo> &E : *stages) {
+								if (E.value.stage_functions.has(name)) {
+									// Register usage of the restricted stage function.
+									calls_info[current_function].uses_restricted_items.push_back(Pair<StringName, CallInfo::Item>(name, CallInfo::Item(CallInfo::Item::ITEM_TYPE_BUILTIN, _get_tkpos())));
+									is_builtin = true;
+									break;
 								}
-								// Register usage of the restricted function.
-								calls_info[current_function].uses_restricted_items.push_back(Pair<StringName, CallInfo::Item>(name, CallInfo::Item(CallInfo::Item::ITEM_TYPE_BUILTIN, _get_tkpos())));
-								is_builtin = true;
-								break;
 							}
-							idx++;
+						}
+
+						if (!is_builtin) {
+							int idx = 0;
+							while (frag_only_func_defs[idx].name) {
+								if (frag_only_func_defs[idx].name == name) {
+									// If a built-in function not found for the current shader type, then it shouldn't be parsed further.
+									if (!is_supported_frag_only_funcs) {
+										_set_error(vformat(RTR("Built-in function '%s' is not supported for the '%s' shader type."), name, shader_type_identifier));
+										return nullptr;
+									}
+									// Register usage of the restricted function.
+									calls_info[current_function].uses_restricted_items.push_back(Pair<StringName, CallInfo::Item>(name, CallInfo::Item(CallInfo::Item::ITEM_TYPE_BUILTIN, _get_tkpos())));
+									is_builtin = true;
+									break;
+								}
+								idx++;
+							}
 						}
 
 						// Recursively checks for the restricted function call.
@@ -11160,9 +11178,15 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 				int idx = 0;
 				bool low_end = RenderingServer::get_singleton()->is_low_end();
 
-				if (stages && stages->has(skip_function)) {
-					for (const KeyValue<StringName, StageFunctionInfo> &E : (*stages)[skip_function].stage_functions) {
-						matches.insert(String(E.key), ScriptLanguage::CODE_COMPLETION_KIND_FUNCTION);
+				if (stages) {
+					// Stage functions can be used in custom functions as well, that why need to check them all.
+					for (const KeyValue<StringName, FunctionInfo> &E : *stages) {
+						for (const KeyValue<StringName, StageFunctionInfo> &F : E.value.stage_functions) {
+							if (F.value.skip_function == skip_function && stages->has(skip_function)) {
+								continue;
+							}
+							matches.insert(String(F.key), ScriptLanguage::CODE_COMPLETION_KIND_FUNCTION);
+						}
 					}
 				}
 
@@ -11292,9 +11316,15 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 				return OK;
 			}
 
-			if (stages && stages->has(block_function)) {
-				for (const KeyValue<StringName, StageFunctionInfo> &E : (*stages)[block_function].stage_functions) {
-					if (completion_function == E.key) {
+			if (stages) {
+				// Stage functions can be used in custom functions as well, that why need to check them all.
+				for (const KeyValue<StringName, FunctionInfo> &S : *stages) {
+					for (const KeyValue<StringName, StageFunctionInfo> &E : S.value.stage_functions) {
+						// No need to check for the skip function here.
+						if (completion_function != E.key) {
+							continue;
+						}
+
 						calltip += get_datatype_name(E.value.return_type);
 						calltip += " ";
 						calltip += E.key;

+ 1 - 0
servers/rendering/shader_language.h

@@ -859,6 +859,7 @@ public:
 
 		Vector<Argument> arguments;
 		DataType return_type = TYPE_VOID;
+		String skip_function;
 	};
 
 	struct ModeInfo {

+ 2 - 0
servers/rendering/shader_types.cpp

@@ -284,6 +284,7 @@ ShaderTypes::ShaderTypes() {
 
 	{
 		ShaderLanguage::StageFunctionInfo func;
+		func.skip_function = "vertex";
 		func.arguments.push_back(ShaderLanguage::StageFunctionInfo::Argument("sdf_pos", ShaderLanguage::TYPE_VEC2));
 		func.return_type = ShaderLanguage::TYPE_FLOAT; //whether it could emit
 		shader_modes[RS::SHADER_CANVAS_ITEM].functions["fragment"].stage_functions["texture_sdf"] = func;
@@ -297,6 +298,7 @@ ShaderTypes::ShaderTypes() {
 
 	{
 		ShaderLanguage::StageFunctionInfo func;
+		func.skip_function = "vertex";
 		func.arguments.push_back(ShaderLanguage::StageFunctionInfo::Argument("uv", ShaderLanguage::TYPE_VEC2));
 		func.return_type = ShaderLanguage::TYPE_VEC2; //whether it could emit
 		shader_modes[RS::SHADER_CANVAS_ITEM].functions["fragment"].stage_functions["screen_uv_to_sdf"] = func;