Browse Source

Forbid calling of derivative functions in incorrect functions

Yuri Rubinsky 1 year ago
parent
commit
ae95531e64
2 changed files with 124 additions and 2 deletions
  1. 106 2
      servers/rendering/shader_language.cpp
  2. 18 0
      servers/rendering/shader_language.h

+ 106 - 2
servers/rendering/shader_language.cpp

@@ -1238,6 +1238,7 @@ void ShaderLanguage::clear() {
 	include_positions.push_back(FilePosition());
 	include_positions.push_back(FilePosition());
 
 
 	include_markers_handled.clear();
 	include_markers_handled.clear();
+	calls_info.clear();
 
 
 #ifdef DEBUG_ENABLED
 #ifdef DEBUG_ENABLED
 	keyword_completion_context = CF_UNSPECIFIED;
 	keyword_completion_context = CF_UNSPECIFIED;
@@ -3085,6 +3086,19 @@ const ShaderLanguage::BuiltinFuncConstArgs ShaderLanguage::builtin_func_const_ar
 	{ nullptr, 0, 0, 0 }
 	{ nullptr, 0, 0, 0 }
 };
 };
 
 
+const ShaderLanguage::BuiltinEntry ShaderLanguage::frag_only_func_defs[] = {
+	{ "dFdx" },
+	{ "dFdxCoarse" },
+	{ "dFdxFine" },
+	{ "dFdy" },
+	{ "dFdyCoarse" },
+	{ "dFdyFine" },
+	{ "fwidth" },
+	{ "fwidthCoarse" },
+	{ "fwidthFine" },
+	{ nullptr }
+};
+
 bool ShaderLanguage::is_const_suffix_lut_initialized = false;
 bool ShaderLanguage::is_const_suffix_lut_initialized = false;
 
 
 bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_func, DataType *r_ret_type, StringName *r_ret_type_str, bool *r_is_custom_function) {
 bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_func, DataType *r_ret_type, StringName *r_ret_type_str, bool *r_is_custom_function) {
@@ -4610,6 +4624,58 @@ bool ShaderLanguage::_check_node_constness(const Node *p_node) const {
 	return true;
 	return true;
 }
 }
 
 
+bool ShaderLanguage::_check_restricted_func(const StringName &p_name, const StringName &p_current_function) const {
+	int idx = 0;
+
+	while (frag_only_func_defs[idx].name) {
+		if (StringName(frag_only_func_defs[idx].name) == p_name) {
+			if (is_supported_frag_only_funcs) {
+				if (p_current_function == "vertex" && stages->has(p_current_function)) {
+					return true;
+				}
+			} else {
+				return true;
+			}
+			break;
+		}
+		idx++;
+	}
+
+	return false;
+}
+
+bool ShaderLanguage::_validate_restricted_func(const StringName &p_name, const CallInfo *p_func_info, bool p_is_builtin_hint) {
+	const bool is_in_restricted_function = p_func_info->name == "vertex";
+
+	// No need to check up the hierarchy if it's a built-in.
+	if (!p_is_builtin_hint) {
+		for (const CallInfo *func_info : p_func_info->calls) {
+			if (is_in_restricted_function && func_info->name != p_name) {
+				// Skips check for non-called method.
+				continue;
+			}
+
+			if (!_validate_restricted_func(p_name, func_info)) {
+				return false;
+			}
+		}
+	}
+
+	if (!p_func_info->uses_restricted_functions.is_empty()) {
+		const Pair<StringName, TkPos> &first_element = p_func_info->uses_restricted_functions.get(0);
+		_set_tkpos(first_element.second);
+
+		if (is_in_restricted_function) {
+			_set_error(vformat(RTR("'%s' cannot be used within the '%s' processor function."), first_element.first, "vertex"));
+		} else {
+			_set_error(vformat(RTR("'%s' cannot be used here, because '%s' is called by the '%s' processor function (which is not allowed)."), first_element.first, p_func_info->name, "vertex"));
+		}
+		return false;
+	}
+
+	return true;
+}
+
 bool ShaderLanguage::_validate_assign(Node *p_node, const FunctionInfo &p_function_info, String *r_message) {
 bool ShaderLanguage::_validate_assign(Node *p_node, const FunctionInfo &p_function_info, String *r_message) {
 	if (p_node->type == Node::NODE_TYPE_OPERATOR) {
 	if (p_node->type == Node::NODE_TYPE_OPERATOR) {
 		OperatorNode *op = static_cast<OperatorNode *>(p_node);
 		OperatorNode *op = static_cast<OperatorNode *>(p_node);
@@ -5266,6 +5332,36 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 
 
 					const StringName &name = identifier;
 					const StringName &name = identifier;
 
 
+					if (name != current_function) { // Recursion is not allowed.
+						// Register call.
+						if (calls_info.has(name)) {
+							calls_info[current_function].calls.push_back(&calls_info[name]);
+						}
+
+						int idx = 0;
+						bool is_builtin = false;
+
+						while (frag_only_func_defs[idx].name) {
+							if (frag_only_func_defs[idx].name == name) {
+								// If a built-in function not found for the current shader type, then it shouldn't be parsed further.
+								if (!is_supported_frag_only_funcs) {
+									_set_error(vformat(RTR("Built-in function '%s' is not supported for the '%s' shader type."), name, shader_type_identifier));
+									return nullptr;
+								}
+								// Register usage of the restricted function.
+								calls_info[current_function].uses_restricted_functions.push_back(Pair<StringName, TkPos>(name, _get_tkpos()));
+								is_builtin = true;
+								break;
+							}
+							idx++;
+						}
+
+						// Recursively checks for the restricted function call.
+						if (is_supported_frag_only_funcs && current_function == "vertex" && stages->has(current_function) && !_validate_restricted_func(name, &calls_info[current_function], is_builtin)) {
+							return nullptr;
+						}
+					}
+
 					OperatorNode *func = alloc_node<OperatorNode>();
 					OperatorNode *func = alloc_node<OperatorNode>();
 					func->op = OP_CALL;
 					func->op = OP_CALL;
 					VariableNode *funcname = alloc_node<VariableNode>();
 					VariableNode *funcname = alloc_node<VariableNode>();
@@ -8099,6 +8195,8 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 	ShaderNode::Uniform::Scope uniform_scope = ShaderNode::Uniform::SCOPE_LOCAL;
 	ShaderNode::Uniform::Scope uniform_scope = ShaderNode::Uniform::SCOPE_LOCAL;
 
 
 	stages = &p_functions;
 	stages = &p_functions;
+	is_supported_frag_only_funcs = shader_type_identifier == "canvas_item" || shader_type_identifier == "spatial" || shader_type_identifier == "sky";
+
 	const FunctionInfo &constants = p_functions.has("constants") ? p_functions["constants"] : FunctionInfo();
 	const FunctionInfo &constants = p_functions.has("constants") ? p_functions["constants"] : FunctionInfo();
 
 
 	HashMap<String, String> defined_modes;
 	HashMap<String, String> defined_modes;
@@ -9541,6 +9639,11 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 				shader->functions.insert(name, function);
 				shader->functions.insert(name, function);
 				shader->vfunctions.push_back(function);
 				shader->vfunctions.push_back(function);
 
 
+				CallInfo call_info;
+				call_info.name = name;
+
+				calls_info.insert(name, call_info);
+
 				func_node->name = name;
 				func_node->name = name;
 				func_node->return_type = type;
 				func_node->return_type = type;
 				func_node->return_struct_name = struct_name;
 				func_node->return_struct_name = struct_name;
@@ -10325,10 +10428,11 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 				}
 				}
 
 
 				while (builtin_func_defs[idx].name) {
 				while (builtin_func_defs[idx].name) {
-					if (low_end && builtin_func_defs[idx].high_end) {
+					if ((low_end && builtin_func_defs[idx].high_end) || _check_restricted_func(builtin_func_defs[idx].name, skip_function)) {
 						idx++;
 						idx++;
 						continue;
 						continue;
 					}
 					}
+
 					matches.insert(String(builtin_func_defs[idx].name), ScriptLanguage::CODE_COMPLETION_KIND_FUNCTION);
 					matches.insert(String(builtin_func_defs[idx].name), ScriptLanguage::CODE_COMPLETION_KIND_FUNCTION);
 					idx++;
 					idx++;
 				}
 				}
@@ -10490,7 +10594,7 @@ Error ShaderLanguage::complete(const String &p_code, const ShaderCompileInfo &p_
 			}
 			}
 
 
 			while (builtin_func_defs[idx].name) {
 			while (builtin_func_defs[idx].name) {
-				if (low_end && builtin_func_defs[idx].high_end) {
+				if ((low_end && builtin_func_defs[idx].high_end) || _check_restricted_func(builtin_func_defs[idx].name, block_function)) {
 					idx++;
 					idx++;
 					continue;
 					continue;
 				}
 				}

+ 18 - 0
servers/rendering/shader_language.h

@@ -913,6 +913,15 @@ private:
 	Vector<FilePosition> include_positions;
 	Vector<FilePosition> include_positions;
 	HashSet<String> include_markers_handled;
 	HashSet<String> include_markers_handled;
 
 
+	// Additional function information (eg. call hierarchy). No need to expose it to compiler.
+	struct CallInfo {
+		StringName name;
+		List<Pair<StringName, TkPos>> uses_restricted_functions;
+		List<CallInfo *> calls;
+	};
+
+	RBMap<StringName, CallInfo> calls_info;
+
 #ifdef DEBUG_ENABLED
 #ifdef DEBUG_ENABLED
 	struct Usage {
 	struct Usage {
 		int decl_line;
 		int decl_line;
@@ -1036,6 +1045,10 @@ private:
 	bool _validate_assign(Node *p_node, const FunctionInfo &p_function_info, String *r_message = nullptr);
 	bool _validate_assign(Node *p_node, const FunctionInfo &p_function_info, String *r_message = nullptr);
 	bool _validate_operator(OperatorNode *p_op, DataType *r_ret_type = nullptr, int *r_ret_size = nullptr);
 	bool _validate_operator(OperatorNode *p_op, DataType *r_ret_type = nullptr, int *r_ret_size = nullptr);
 
 
+	struct BuiltinEntry {
+		const char *name;
+	};
+
 	struct BuiltinFuncDef {
 	struct BuiltinFuncDef {
 		enum { MAX_ARGS = 5 };
 		enum { MAX_ARGS = 5 };
 		const char *name;
 		const char *name;
@@ -1078,11 +1091,13 @@ private:
 #endif // DEBUG_ENABLED
 #endif // DEBUG_ENABLED
 
 
 	const HashMap<StringName, FunctionInfo> *stages = nullptr;
 	const HashMap<StringName, FunctionInfo> *stages = nullptr;
+	bool is_supported_frag_only_funcs = false;
 
 
 	bool _get_completable_identifier(BlockNode *p_block, CompletionType p_type, StringName &identifier);
 	bool _get_completable_identifier(BlockNode *p_block, CompletionType p_type, StringName &identifier);
 	static const BuiltinFuncDef builtin_func_defs[];
 	static const BuiltinFuncDef builtin_func_defs[];
 	static const BuiltinFuncOutArgs builtin_func_out_args[];
 	static const BuiltinFuncOutArgs builtin_func_out_args[];
 	static const BuiltinFuncConstArgs builtin_func_const_args[];
 	static const BuiltinFuncConstArgs builtin_func_const_args[];
+	static const BuiltinEntry frag_only_func_defs[];
 
 
 	static bool is_const_suffix_lut_initialized;
 	static bool is_const_suffix_lut_initialized;
 
 
@@ -1097,6 +1112,9 @@ private:
 	bool _validate_varying_assign(ShaderNode::Varying &p_varying, String *r_message);
 	bool _validate_varying_assign(ShaderNode::Varying &p_varying, String *r_message);
 	bool _check_node_constness(const Node *p_node) const;
 	bool _check_node_constness(const Node *p_node) const;
 
 
+	bool _check_restricted_func(const StringName &p_name, const StringName &p_current_function) const;
+	bool _validate_restricted_func(const StringName &p_call_name, const CallInfo *p_func_info, bool p_is_builtin_hint = false);
+
 	Node *_parse_expression(BlockNode *p_block, const FunctionInfo &p_function_info);
 	Node *_parse_expression(BlockNode *p_block, const FunctionInfo &p_function_info);
 	Error _parse_array_size(BlockNode *p_block, const FunctionInfo &p_function_info, bool p_forbid_unknown_size, Node **r_size_expression, int *r_array_size, bool *r_unknown_size);
 	Error _parse_array_size(BlockNode *p_block, const FunctionInfo &p_function_info, bool p_forbid_unknown_size, Node **r_size_expression, int *r_array_size, bool *r_unknown_size);
 	Node *_parse_array_constructor(BlockNode *p_block, const FunctionInfo &p_function_info);
 	Node *_parse_array_constructor(BlockNode *p_block, const FunctionInfo &p_function_info);