Преглед изворни кода

Merge pull request #92441 from Chaosus/shader_custom_func_overloads

Implement custom function overloading in shading language
Thaddeus Crews пре 10 месеци
родитељ
комит
af77100e39

+ 2 - 2
servers/rendering/shader_compiler.cpp

@@ -355,7 +355,7 @@ void ShaderCompiler::_dump_function_deps(const SL::ShaderNode *p_node, const Str
 		}
 
 		header += " ";
-		header += _mkid(fnode->name);
+		header += _mkid(fnode->rname);
 		header += "(";
 
 		for (int i = 0; i < fnode->arguments.size(); i++) {
@@ -1190,7 +1190,7 @@ String ShaderCompiler::_dump_node_code(const SL::Node *p_node, int p_level, Gene
 						} else if (p_default_actions.renames.has(vnode->name)) {
 							code += p_default_actions.renames[vnode->name];
 						} else {
-							code += _mkid(vnode->name);
+							code += _mkid(vnode->rname);
 						}
 					}
 

+ 238 - 126
servers/rendering/shader_language.cpp

@@ -1305,6 +1305,7 @@ void ShaderLanguage::clear() {
 
 	include_markers_handled.clear();
 	calls_info.clear();
+	function_overload_count.clear();
 
 #ifdef DEBUG_ENABLED
 	keyword_completion_context = CF_UNSPECIFIED;
@@ -3554,6 +3555,7 @@ bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionI
 	ERR_FAIL_COND_V(p_func->arguments[0]->type != Node::NODE_TYPE_VARIABLE, false);
 
 	StringName name = static_cast<VariableNode *>(p_func->arguments[0])->name.operator String();
+	StringName rname = static_cast<VariableNode *>(p_func->arguments[0])->rname.operator String();
 
 	for (int i = 1; i < p_func->arguments.size(); i++) {
 		args.push_back(p_func->arguments[i]->get_datatype());
@@ -3891,9 +3893,10 @@ bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionI
 		}
 
 		bool fail = false;
+		bool use_constant_conversion = function_overload_count[rname] == 0;
 
 		for (int j = 0; j < args.size(); j++) {
-			if (get_scalar_type(args[j]) == args[j] && p_func->arguments[j + 1]->type == Node::NODE_TYPE_CONSTANT && args3[j] == 0 && convert_constant(static_cast<ConstantNode *>(p_func->arguments[j + 1]), pfunc->arguments[j].type)) {
+			if (use_constant_conversion && get_scalar_type(args[j]) == args[j] && p_func->arguments[j + 1]->type == Node::NODE_TYPE_CONSTANT && args3[j] == 0 && convert_constant(static_cast<ConstantNode *>(p_func->arguments[j + 1]), pfunc->arguments[j].type)) {
 				//all good, but it needs implicit conversion later
 			} else if (args[j] != pfunc->arguments[j].type || (args[j] == TYPE_STRUCT && args2[j] != pfunc->arguments[j].struct_name) || args3[j] != pfunc->arguments[j].array_size) {
 				String func_arg_name;
@@ -3919,7 +3922,7 @@ bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionI
 					arg_name += "]";
 				}
 
-				_set_error(vformat(RTR("Invalid argument for \"%s(%s)\" function: argument %d should be %s but is %s."), String(name), arg_list, j + 1, func_arg_name, arg_name));
+				_set_error(vformat(RTR("Invalid argument for \"%s(%s)\" function: argument %d should be %s but is %s."), String(rname), arg_list, j + 1, func_arg_name, arg_name));
 				fail = true;
 				break;
 			}
@@ -3960,9 +3963,9 @@ bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionI
 
 	if (exists) {
 		if (last_arg_count > args.size()) {
-			_set_error(vformat(RTR("Too few arguments for \"%s(%s)\" call. Expected at least %d but received %d."), String(name), arg_list, last_arg_count, args.size()));
+			_set_error(vformat(RTR("Too few arguments for \"%s(%s)\" call. Expected at least %d but received %d."), String(rname), arg_list, last_arg_count, args.size()));
 		} else if (last_arg_count < args.size()) {
-			_set_error(vformat(RTR("Too many arguments for \"%s(%s)\" call. Expected at most %d but received %d."), String(name), arg_list, last_arg_count, args.size()));
+			_set_error(vformat(RTR("Too many arguments for \"%s(%s)\" call. Expected at most %d but received %d."), String(rname), arg_list, last_arg_count, args.size()));
 		}
 	}
 
@@ -5822,42 +5825,16 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 						}
 					}
 
-					const StringName &name = identifier;
-
-					if (name != current_function) { // Recursion is not allowed.
-						// Register call.
-						if (calls_info.has(name)) {
-							calls_info[current_function].calls.push_back(&calls_info[name]);
-						}
-
-						int idx = 0;
-						bool is_builtin = false;
-
-						while (frag_only_func_defs[idx].name) {
-							if (frag_only_func_defs[idx].name == name) {
-								// If a built-in function not found for the current shader type, then it shouldn't be parsed further.
-								if (!is_supported_frag_only_funcs) {
-									_set_error(vformat(RTR("Built-in function '%s' is not supported for the '%s' shader type."), name, shader_type_identifier));
-									return nullptr;
-								}
-								// Register usage of the restricted function.
-								calls_info[current_function].uses_restricted_items.push_back(Pair<StringName, CallInfo::Item>(name, CallInfo::Item(CallInfo::Item::ITEM_TYPE_BUILTIN, _get_tkpos())));
-								is_builtin = true;
-								break;
-							}
-							idx++;
-						}
-
-						// Recursively checks for the restricted function call.
-						if (is_supported_frag_only_funcs && current_function == "vertex" && stages->has(current_function) && !_validate_restricted_func(name, &calls_info[current_function], is_builtin)) {
-							return nullptr;
-						}
-					}
+					const StringName &rname = identifier;
+					StringName name = identifier;
 
 					OperatorNode *func = alloc_node<OperatorNode>();
 					func->op = OP_CALL;
+
 					VariableNode *funcname = alloc_node<VariableNode>();
 					funcname->name = name;
+					funcname->rname = name;
+
 					func->arguments.push_back(funcname);
 
 					int carg = -1;
@@ -5874,22 +5851,72 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 						bnode = bnode->parent_block;
 					}
 
-					//test if function was parsed first
+					// Test if function was parsed first.
 					int function_index = -1;
-					for (int i = 0; i < shader->vfunctions.size(); i++) {
-						if (shader->vfunctions[i].name == name) {
-							//add to current function as dependency
-							for (int j = 0; j < shader->vfunctions.size(); j++) {
-								if (shader->vfunctions[j].name == current_function) {
-									shader->vfunctions.write[j].uses_function.insert(name);
-									break;
+					for (int i = 0, max_valid_args = 0; i < shader->vfunctions.size(); i++) {
+						if (!shader->vfunctions[i].callable || shader->vfunctions[i].rname != rname) {
+							continue;
+						}
+
+						bool found = true;
+						int valid_args = 0;
+
+						// Search for correct overload.
+						for (int j = 1; j < func->arguments.size(); j++) {
+							if (j - 1 == shader->vfunctions[i].function->arguments.size()) {
+								found = false;
+								break;
+							}
+
+							const FunctionNode::Argument &a = shader->vfunctions[i].function->arguments[j - 1];
+							Node *b = func->arguments[j];
+
+							if (a.type == b->get_datatype() && a.array_size == b->get_array_size()) {
+								if (a.type == TYPE_STRUCT) {
+									if (a.struct_name != b->get_datatype_name()) {
+										found = false;
+										break;
+									} else {
+										valid_args++;
+									}
+								} else {
+									valid_args++;
 								}
+							} else {
+								if (function_overload_count[rname] == 0 && get_scalar_type(a.type) == a.type && b->type == Node::NODE_TYPE_CONSTANT && a.array_size == 0 && convert_constant(static_cast<ConstantNode *>(b), a.type)) {
+									// Implicit cast if no overloads.
+									continue;
+								}
+								found = false;
+								break;
 							}
+						}
 
-							//see if texture arguments must connect
-							function_index = i;
-							break;
+						// Using the best match index for completion hint if the function not found.
+						if (valid_args > max_valid_args) {
+							name = shader->vfunctions[i].name;
+							funcname->name = name;
+							max_valid_args = valid_args;
+						}
+
+						if (!found) {
+							continue;
 						}
+
+						// Add to current function as dependency.
+						for (int j = 0; j < shader->vfunctions.size(); j++) {
+							if (shader->vfunctions[j].name == current_function) {
+								shader->vfunctions.write[j].uses_function.insert(name);
+								break;
+							}
+						}
+
+						name = shader->vfunctions[i].name;
+						funcname->name = name;
+
+						// See if texture arguments must connect.
+						function_index = i;
+						break;
 					}
 
 					if (carg >= 0) {
@@ -5904,9 +5931,39 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 						return nullptr;
 					}
 
+					if (name != current_function) { // Recursion is not allowed.
+						// Register call.
+						if (calls_info.has(name)) {
+							calls_info[current_function].calls.push_back(&calls_info[name]);
+						}
+
+						int idx = 0;
+						bool is_builtin = false;
+
+						while (frag_only_func_defs[idx].name) {
+							if (frag_only_func_defs[idx].name == name) {
+								// If a built-in function not found for the current shader type, then it shouldn't be parsed further.
+								if (!is_supported_frag_only_funcs) {
+									_set_error(vformat(RTR("Built-in function '%s' is not supported for the '%s' shader type."), name, shader_type_identifier));
+									return nullptr;
+								}
+								// Register usage of the restricted function.
+								calls_info[current_function].uses_restricted_items.push_back(Pair<StringName, CallInfo::Item>(name, CallInfo::Item(CallInfo::Item::ITEM_TYPE_BUILTIN, _get_tkpos())));
+								is_builtin = true;
+								break;
+							}
+							idx++;
+						}
+
+						// Recursively checks for the restricted function call.
+						if (is_supported_frag_only_funcs && current_function == "vertex" && stages->has(current_function) && !_validate_restricted_func(name, &calls_info[current_function], is_builtin)) {
+							return nullptr;
+						}
+					}
+
 					bool is_custom_func = false;
 					if (!_validate_function_call(p_block, p_function_info, func, &func->return_cache, &func->struct_name, &is_custom_func)) {
-						_set_error(vformat(RTR("No matching function found for: '%s'."), String(funcname->name)));
+						_set_error(vformat(RTR("No matching function found for: '%s'."), String(funcname->rname)));
 						return nullptr;
 					}
 					completion_class = TAG_GLOBAL; // reset sub-class
@@ -6095,7 +6152,7 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 					}
 					expr = func;
 #ifdef DEBUG_ENABLED
-					if (check_warnings) {
+					if (check_warnings && is_custom_func) {
 						StringName func_name;
 
 						if (p_block && p_block->parent_function) {
@@ -9932,7 +9989,8 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 					return ERR_PARSE_ERROR;
 				}
 
-				if (shader->structs.has(name) || _find_identifier(nullptr, false, constants, name) || has_builtin(p_functions, name, !is_constant)) {
+				IdentifierType itype;
+				if (shader->structs.has(name) || (_find_identifier(nullptr, false, constants, name, nullptr, &itype) && itype != IDENTIFIER_FUNCTION) || has_builtin(p_functions, name, !is_constant)) {
 					_set_redefinition_error(String(name));
 					return ERR_PARSE_ERROR;
 				}
@@ -10260,20 +10318,13 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 
 				function.callable = !p_functions.has(name);
 				function.name = name;
+				function.rname = name;
 
 				FunctionNode *func_node = alloc_node<FunctionNode>();
-
 				function.function = func_node;
 
-				shader->functions.insert(name, function);
-				shader->vfunctions.push_back(function);
-
-				CallInfo call_info;
-				call_info.name = name;
-
-				calls_info.insert(name, call_info);
-
 				func_node->name = name;
+				func_node->rname = name;
 				func_node->return_type = type;
 				func_node->return_struct_name = struct_name;
 				func_node->return_precision = precision;
@@ -10281,12 +10332,12 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 
 				if (p_functions.has(name)) {
 					func_node->can_discard = p_functions[name].can_discard;
+				}
+
+				if (!function_overload_count.has(name)) {
+					function_overload_count.insert(name, 0);
 				} else {
-#ifdef DEBUG_ENABLED
-					if (check_warnings && HAS_WARNING(ShaderWarning::UNUSED_FUNCTION_FLAG)) {
-						used_functions.insert(name, Usage(tk_line));
-					}
-#endif // DEBUG_ENABLED
+					function_overload_count[name]++;
 				}
 
 				func_node->body = alloc_node<BlockNode>();
@@ -10466,7 +10517,6 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 
 					param_name = tk.text;
 
-					ShaderLanguage::IdentifierType itype;
 					if (_find_identifier(func_node->body, false, builtins, param_name, (ShaderLanguage::DataType *)nullptr, &itype)) {
 						if (itype != IDENTIFIER_FUNCTION) {
 							_set_redefinition_error(String(param_name));
@@ -10512,6 +10562,66 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 					}
 				}
 
+				// Searches for function index and check for the exact duplicate in overloads.
+				int function_index = 0;
+				for (int i = 0; i < shader->vfunctions.size(); i++) {
+					if (!shader->vfunctions[i].callable || shader->vfunctions[i].rname != name) {
+						continue;
+					}
+
+					function_index++;
+
+					if (shader->vfunctions[i].function->arguments.size() != func_node->arguments.size()) {
+						continue;
+					}
+
+					bool is_same = true;
+
+					for (int j = 0; j < shader->vfunctions[i].function->arguments.size(); j++) {
+						FunctionNode::Argument a = func_node->arguments[j];
+						FunctionNode::Argument b = shader->vfunctions[i].function->arguments[j];
+
+						if (a.type == b.type && a.array_size == b.array_size) {
+							if (a.type == TYPE_STRUCT) {
+								is_same = a.struct_name == b.struct_name;
+							}
+						} else {
+							is_same = false;
+						}
+
+						if (!is_same) {
+							break;
+						}
+					}
+
+					if (is_same) {
+						_set_redefinition_error(String(name));
+						return ERR_PARSE_ERROR;
+					}
+				}
+
+				// Creates a fake name for function overload, which will be replaced by the real name by the compiler.
+				String name2 = name;
+				if (function_index > 0) {
+					name2 = vformat("%s@%s", name, itos(function_index + 1));
+
+					function.name = name2;
+					func_node->name = name2;
+				}
+
+				shader->functions.insert(name2, function);
+				shader->vfunctions.push_back(function);
+
+				CallInfo call_info;
+				call_info.name = name2;
+				calls_info.insert(name2, call_info);
+
+#ifdef DEBUG_ENABLED
+				if (check_warnings && HAS_WARNING(ShaderWarning::UNUSED_FUNCTION_FLAG) && !p_functions.has(name)) {
+					used_functions.insert(name2, Usage(tk_line));
+				}
+#endif // DEBUG_ENABLED
+
 				if (p_functions.has(name)) {
 					//if one of the core functions, make sure they are of the correct form
 					if (func_node->arguments.size() > 0) {
@@ -10531,7 +10641,7 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 					return ERR_PARSE_ERROR;
 				}
 
-				current_function = name;
+				current_function = name2;
 
 #ifdef DEBUG_ENABLED
 				keyword_completion_context = CF_BLOCK;
@@ -11044,7 +11154,7 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 					if (!shader->vfunctions[i].callable || shader->vfunctions[i].name == skip_function) {
 						continue;
 					}
-					matches.insert(String(shader->vfunctions[i].name), ScriptLanguage::CODE_COMPLETION_KIND_FUNCTION);
+					matches.insert(String(shader->vfunctions[i].rname), ScriptLanguage::CODE_COMPLETION_KIND_FUNCTION);
 				}
 
 				int idx = 0;
@@ -11095,6 +11205,7 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 		case COMPLETION_CALL_ARGUMENTS: {
 			StringName block_function;
 			BlockNode *block = completion_block;
+			String calltip;
 
 			while (block) {
 				if (block->parent_function) {
@@ -11103,85 +11214,83 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 				block = block->parent_block;
 			}
 
-			for (int i = 0; i < shader->vfunctions.size(); i++) {
-				if (!shader->vfunctions[i].callable) {
+			for (int i = 0, overload_index = 0; i < shader->vfunctions.size(); i++) {
+				if (!shader->vfunctions[i].callable || shader->vfunctions[i].rname != completion_function) {
 					continue;
 				}
-				if (shader->vfunctions[i].name == completion_function) {
-					String calltip;
-
-					if (shader->vfunctions[i].function->return_type == TYPE_STRUCT) {
-						calltip += String(shader->vfunctions[i].function->return_struct_name);
-					} else {
-						calltip += get_datatype_name(shader->vfunctions[i].function->return_type);
-					}
 
-					if (shader->vfunctions[i].function->return_array_size > 0) {
-						calltip += "[";
-						calltip += itos(shader->vfunctions[i].function->return_array_size);
-						calltip += "]";
-					}
+				if (shader->vfunctions[i].function->return_type == TYPE_STRUCT) {
+					calltip += String(shader->vfunctions[i].function->return_struct_name);
+				} else {
+					calltip += get_datatype_name(shader->vfunctions[i].function->return_type);
+				}
 
-					calltip += " ";
-					calltip += shader->vfunctions[i].name;
-					calltip += "(";
+				if (shader->vfunctions[i].function->return_array_size > 0) {
+					calltip += "[";
+					calltip += itos(shader->vfunctions[i].function->return_array_size);
+					calltip += "]";
+				}
 
-					for (int j = 0; j < shader->vfunctions[i].function->arguments.size(); j++) {
-						if (j > 0) {
-							calltip += ", ";
-						} else {
-							calltip += " ";
-						}
+				calltip += " ";
+				calltip += shader->vfunctions[i].rname;
+				calltip += "(";
 
-						if (j == completion_argument) {
-							calltip += char32_t(0xFFFF);
-						}
+				for (int j = 0; j < shader->vfunctions[i].function->arguments.size(); j++) {
+					if (j > 0) {
+						calltip += ", ";
+					} else {
+						calltip += " ";
+					}
 
-						if (shader->vfunctions[i].function->arguments[j].is_const) {
-							calltip += "const ";
-						}
+					if (j == completion_argument) {
+						calltip += char32_t(0xFFFF);
+					}
 
-						if (shader->vfunctions[i].function->arguments[j].qualifier != ArgumentQualifier::ARGUMENT_QUALIFIER_IN) {
-							if (shader->vfunctions[i].function->arguments[j].qualifier == ArgumentQualifier::ARGUMENT_QUALIFIER_OUT) {
-								calltip += "out ";
-							} else { // ArgumentQualifier::ARGUMENT_QUALIFIER_INOUT
-								calltip += "inout ";
-							}
-						}
+					if (shader->vfunctions[i].function->arguments[j].is_const) {
+						calltip += "const ";
+					}
 
-						if (shader->vfunctions[i].function->arguments[j].type == TYPE_STRUCT) {
-							calltip += String(shader->vfunctions[i].function->arguments[j].struct_name);
-						} else {
-							calltip += get_datatype_name(shader->vfunctions[i].function->arguments[j].type);
+					if (shader->vfunctions[i].function->arguments[j].qualifier != ArgumentQualifier::ARGUMENT_QUALIFIER_IN) {
+						if (shader->vfunctions[i].function->arguments[j].qualifier == ArgumentQualifier::ARGUMENT_QUALIFIER_OUT) {
+							calltip += "out ";
+						} else { // ArgumentQualifier::ARGUMENT_QUALIFIER_INOUT
+							calltip += "inout ";
 						}
-						calltip += " ";
-						calltip += shader->vfunctions[i].function->arguments[j].name;
+					}
 
-						if (shader->vfunctions[i].function->arguments[j].array_size > 0) {
-							calltip += "[";
-							calltip += itos(shader->vfunctions[i].function->arguments[j].array_size);
-							calltip += "]";
-						}
+					if (shader->vfunctions[i].function->arguments[j].type == TYPE_STRUCT) {
+						calltip += String(shader->vfunctions[i].function->arguments[j].struct_name);
+					} else {
+						calltip += get_datatype_name(shader->vfunctions[i].function->arguments[j].type);
+					}
+					calltip += " ";
+					calltip += shader->vfunctions[i].function->arguments[j].name;
 
-						if (j == completion_argument) {
-							calltip += char32_t(0xFFFF);
-						}
+					if (shader->vfunctions[i].function->arguments[j].array_size > 0) {
+						calltip += "[";
+						calltip += itos(shader->vfunctions[i].function->arguments[j].array_size);
+						calltip += "]";
 					}
 
-					if (shader->vfunctions[i].function->arguments.size()) {
-						calltip += " ";
+					if (j == completion_argument) {
+						calltip += char32_t(0xFFFF);
 					}
-					calltip += ")";
+				}
 
-					r_call_hint = calltip;
-					return OK;
+				if (shader->vfunctions[i].function->arguments.size()) {
+					calltip += " ";
 				}
-			}
+				calltip += ")";
 
-			int idx = 0;
+				if (overload_index < function_overload_count[shader->vfunctions[i].rname]) {
+					overload_index++;
+					calltip += "\n";
+					continue;
+				}
 
-			String calltip;
-			bool low_end = RenderingServer::get_singleton()->is_low_end();
+				r_call_hint = calltip;
+				return OK;
+			}
 
 			if (stages && stages->has(block_function)) {
 				for (const KeyValue<StringName, StageFunctionInfo> &E : (*stages)[block_function].stage_functions) {
@@ -11222,6 +11331,9 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 				}
 			}
 
+			int idx = 0;
+			bool low_end = RenderingServer::get_singleton()->is_low_end();
+
 			while (builtin_func_defs[idx].name) {
 				if ((low_end && builtin_func_defs[idx].high_end) || _check_restricted_func(builtin_func_defs[idx].name, block_function)) {
 					idx++;

+ 4 - 0
servers/rendering/shader_language.h

@@ -427,6 +427,7 @@ public:
 	struct VariableNode : public Node {
 		DataType datatype_cache = TYPE_VOID;
 		StringName name;
+		StringName rname;
 		StringName struct_name;
 		bool is_const = false;
 		bool is_local = false;
@@ -604,6 +605,7 @@ public:
 
 		struct Function {
 			StringName name;
+			StringName rname;
 			FunctionNode *function = nullptr;
 			HashSet<StringName> uses_function;
 			bool callable;
@@ -729,6 +731,7 @@ public:
 		};
 
 		StringName name;
+		StringName rname;
 		DataType return_type = TYPE_VOID;
 		StringName return_struct_name;
 		DataPrecision return_precision = PRECISION_DEFAULT;
@@ -944,6 +947,7 @@ private:
 
 	Vector<FilePosition> include_positions;
 	HashSet<String> include_markers_handled;
+	HashMap<StringName, int> function_overload_count;
 
 	// Additional function information (eg. call hierarchy). No need to expose it to compiler.
 	struct CallInfo {