Prechádzať zdrojové kódy

Merge pull request #102289 from Chaosus/shader_fix_constants_in_hint_range

[Shaders] Allow constants and expressions in `hint_range`
Thaddeus Crews 4 mesiacov pred
rodič
commit
7e7c5244c2

+ 58 - 49
servers/rendering/shader_language.cpp

@@ -1399,6 +1399,9 @@ bool ShaderLanguage::_find_identifier(const BlockNode *p_block, bool p_allow_rea
 			if (r_is_const) {
 				*r_is_const = p_function_info.built_ins[p_identifier].constant;
 			}
+			if (r_constant_values) {
+				*r_constant_values = p_function_info.built_ins[p_identifier].values;
+			}
 			if (r_type) {
 				*r_type = IDENTIFIER_BUILTIN_VAR;
 			}
@@ -1551,7 +1554,7 @@ bool ShaderLanguage::_find_identifier(const BlockNode *p_block, bool p_allow_rea
 	return false;
 }
 
-bool ShaderLanguage::_validate_operator(const BlockNode *p_block, OperatorNode *p_op, DataType *r_ret_type, int *r_ret_size) {
+bool ShaderLanguage::_validate_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op, DataType *r_ret_type, int *r_ret_size) {
 	bool valid = false;
 	DataType ret_type = TYPE_VOID;
 	int ret_size = 0;
@@ -2017,18 +2020,18 @@ bool ShaderLanguage::_validate_operator(const BlockNode *p_block, OperatorNode *
 
 	if (valid && (!p_block || p_block->use_op_eval)) {
 		// Need to be placed here and not in the `_reduce_expression` because otherwise expressions like `1 + 2 / 2` will not work correctly.
-		valid = _eval_operator(p_block, p_op);
+		valid = _eval_operator(p_block, p_function_info, p_op);
 	}
 
 	return valid;
 }
 
-Vector<ShaderLanguage::Scalar> ShaderLanguage::_get_node_values(const BlockNode *p_block, Node *p_node) {
+Vector<ShaderLanguage::Scalar> ShaderLanguage::_get_node_values(const BlockNode *p_block, const FunctionInfo &p_function_info, Node *p_node) {
 	Vector<Scalar> result;
 
 	switch (p_node->type) {
 		case Node::NODE_TYPE_VARIABLE: {
-			_find_identifier(p_block, false, FunctionInfo(), static_cast<VariableNode *>(p_node)->name, nullptr, nullptr, nullptr, nullptr, nullptr, &result);
+			_find_identifier(p_block, false, p_function_info, static_cast<VariableNode *>(p_node)->name, nullptr, nullptr, nullptr, nullptr, nullptr, &result);
 		} break;
 		default: {
 			result = p_node->get_values();
@@ -2038,7 +2041,7 @@ Vector<ShaderLanguage::Scalar> ShaderLanguage::_get_node_values(const BlockNode
 	return result;
 }
 
-bool ShaderLanguage::_eval_operator(const BlockNode *p_block, OperatorNode *p_op) {
+bool ShaderLanguage::_eval_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op) {
 	bool is_valid = true;
 
 	switch (p_op->op) {
@@ -2080,8 +2083,8 @@ bool ShaderLanguage::_eval_operator(const BlockNode *p_block, OperatorNode *p_op
 				}
 			}
 
-			Vector<Scalar> va = _get_node_values(p_block, p_op->arguments[0]);
-			Vector<Scalar> vb = _get_node_values(p_block, p_op->arguments[1]);
+			Vector<Scalar> va = _get_node_values(p_block, p_function_info, p_op->arguments[0]);
+			Vector<Scalar> vb = _get_node_values(p_block, p_function_info, p_op->arguments[1]);
 
 			if (is_op_vec_transform) {
 				p_op->values = _eval_vector_transform(va, vb, a, b, p_op->get_datatype());
@@ -2092,7 +2095,7 @@ bool ShaderLanguage::_eval_operator(const BlockNode *p_block, OperatorNode *p_op
 		case OP_NOT:
 		case OP_NEGATE:
 		case OP_BIT_INVERT: {
-			p_op->values = _eval_unary_vector(_get_node_values(p_block, p_op->arguments[0]), p_op->get_datatype(), p_op->op);
+			p_op->values = _eval_unary_vector(_get_node_values(p_block, p_function_info, p_op->arguments[0]), p_op->get_datatype(), p_op->op);
 		} break;
 		default: {
 		} break;
@@ -3661,7 +3664,7 @@ bool ShaderLanguage::_validate_function_call(BlockNode *p_block, const FunctionI
 								int max = builtin_func_const_args[constarg_idx].max;
 
 								bool error = false;
-								Vector<Scalar> values = _get_node_values(p_block, p_func->arguments[arg]);
+								Vector<Scalar> values = _get_node_values(p_block, p_function_info, p_func->arguments[arg]);
 								if (p_func->arguments[arg]->get_datatype() == TYPE_INT && !values.is_empty()) {
 									if (values[0].sint < min || values[0].sint > max) {
 										error = true;
@@ -5662,7 +5665,7 @@ Error ShaderLanguage::_parse_array_size(BlockNode *p_block, const FunctionInfo &
 		Node *expr = _parse_and_reduce_expression(p_block, p_function_info);
 
 		if (expr) {
-			Vector<Scalar> values = _get_node_values(p_block, expr);
+			Vector<Scalar> values = _get_node_values(p_block, p_function_info, expr);
 
 			if (!values.is_empty()) {
 				switch (expr->get_datatype()) {
@@ -7279,7 +7282,7 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 				op->op = tk.type == TK_OP_DECREMENT ? OP_POST_DECREMENT : OP_POST_INCREMENT;
 				op->arguments.push_back(expr);
 
-				if (!_validate_operator(p_block, op, &op->return_cache, &op->return_array_size)) {
+				if (!_validate_operator(p_block, p_function_info, op, &op->return_cache, &op->return_array_size)) {
 					_set_error(RTR("Invalid base type for increment/decrement operator."));
 					return nullptr;
 				}
@@ -7631,7 +7634,7 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 				expression.write[i].is_op = false;
 				expression.write[i].node = op;
 
-				if (!_validate_operator(p_block, op, &op->return_cache, &op->return_array_size)) {
+				if (!_validate_operator(p_block, p_function_info, op, &op->return_cache, &op->return_array_size)) {
 					if (error_set) {
 						return nullptr;
 					}
@@ -7673,7 +7676,7 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 
 			expression.write[next_op - 1].is_op = false;
 			expression.write[next_op - 1].node = op;
-			if (!_validate_operator(p_block, op, &op->return_cache, &op->return_array_size)) {
+			if (!_validate_operator(p_block, p_function_info, op, &op->return_cache, &op->return_array_size)) {
 				if (error_set) {
 					return nullptr;
 				}
@@ -7740,7 +7743,7 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons
 
 			//replace all 3 nodes by this operator and make it an expression
 
-			if (!_validate_operator(p_block, op, &op->return_cache, &op->return_array_size)) {
+			if (!_validate_operator(p_block, p_function_info, op, &op->return_cache, &op->return_array_size)) {
 				if (error_set) {
 					return nullptr;
 				}
@@ -9073,6 +9076,40 @@ Error ShaderLanguage::_validate_precision(DataType p_type, DataPrecision p_preci
 	return OK;
 }
 
+bool ShaderLanguage::_parse_numeric_constant_expression(const FunctionInfo &p_function_info, float &r_constant) {
+	ShaderLanguage::Node *expr = _parse_and_reduce_expression(nullptr, p_function_info);
+	if (expr == nullptr) {
+		return false;
+	}
+
+	Vector<Scalar> values;
+	if (expr->type == Node::NODE_TYPE_VARIABLE) {
+		_find_identifier(nullptr, false, p_function_info, static_cast<VariableNode *>(expr)->name, nullptr, nullptr, nullptr, nullptr, nullptr, &values);
+	} else {
+		values = expr->get_values();
+	}
+
+	if (values.is_empty()) {
+		return false; // To prevent possible crash.
+	}
+
+	switch (expr->get_datatype()) {
+		case TYPE_FLOAT:
+			r_constant = values[0].real;
+			break;
+		case TYPE_INT:
+			r_constant = static_cast<float>(values[0].sint);
+			break;
+		case TYPE_UINT:
+			r_constant = static_cast<float>(values[0].uint);
+			break;
+		default:
+			return false;
+	}
+
+	return true;
+}
+
 Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_functions, const Vector<ModeInfo> &p_render_modes, const HashSet<String> &p_shader_types) {
 	Token tk;
 	TkPos prev_pos;
@@ -9817,58 +9854,30 @@ Error ShaderLanguage::_parse_shader(const HashMap<StringName, FunctionInfo> &p_f
 										return ERR_PARSE_ERROR;
 									}
 
-									tk = _get_token();
-
-									float sign = 1.0;
-
-									if (tk.type == TK_OP_SUB) {
-										sign = -1.0;
-										tk = _get_token();
-									}
-
-									if (tk.type != TK_FLOAT_CONSTANT && !tk.is_integer_constant()) {
-										_set_error(RTR("Expected an integer constant."));
+									if (!_parse_numeric_constant_expression(constants, uniform.hint_range[0])) {
+										_set_error(RTR("Expected a valid numeric expression."));
 										return ERR_PARSE_ERROR;
 									}
 
-									uniform.hint_range[0] = tk.constant;
-									uniform.hint_range[0] *= sign;
-
 									tk = _get_token();
 
 									if (tk.type != TK_COMMA) {
-										_set_error(RTR("Expected ',' after integer constant."));
+										_set_expected_error(",");
 										return ERR_PARSE_ERROR;
 									}
 
-									tk = _get_token();
-
-									sign = 1.0;
-
-									if (tk.type == TK_OP_SUB) {
-										sign = -1.0;
-										tk = _get_token();
-									}
-
-									if (tk.type != TK_FLOAT_CONSTANT && !tk.is_integer_constant()) {
-										_set_error(RTR("Expected an integer constant after ','."));
+									if (!_parse_numeric_constant_expression(constants, uniform.hint_range[1])) {
+										_set_error(RTR("Expected a valid numeric expression after ','."));
 										return ERR_PARSE_ERROR;
 									}
 
-									uniform.hint_range[1] = tk.constant;
-									uniform.hint_range[1] *= sign;
-
 									tk = _get_token();
 
 									if (tk.type == TK_COMMA) {
-										tk = _get_token();
-
-										if (tk.type != TK_FLOAT_CONSTANT && !tk.is_integer_constant()) {
-											_set_error(RTR("Expected an integer constant after ','."));
+										if (!_parse_numeric_constant_expression(constants, uniform.hint_range[2])) {
+											_set_error(RTR("Expected a valid numeric expression after ','."));
 											return ERR_PARSE_ERROR;
 										}
-
-										uniform.hint_range[2] = tk.constant;
 										tk = _get_token();
 									} else {
 										if (type == TYPE_INT) {

+ 8 - 5
servers/rendering/shader_language.h

@@ -862,12 +862,14 @@ public:
 	struct BuiltInInfo {
 		DataType type = TYPE_VOID;
 		bool constant = false;
+		Vector<Scalar> values;
 
 		BuiltInInfo() {}
 
-		BuiltInInfo(DataType p_type, bool p_constant = false) :
+		BuiltInInfo(DataType p_type, bool p_constant = false, const Vector<Scalar> &p_values = {}) :
 				type(p_type),
-				constant(p_constant) {}
+				constant(p_constant),
+				values(p_values) {}
 	};
 
 	struct StageFunctionInfo {
@@ -1119,10 +1121,10 @@ private:
 #endif // DEBUG_ENABLED
 	bool _is_operator_assign(Operator p_op) const;
 	bool _validate_assign(Node *p_node, const FunctionInfo &p_function_info, String *r_message = nullptr);
-	bool _validate_operator(const BlockNode *p_block, OperatorNode *p_op, DataType *r_ret_type = nullptr, int *r_ret_size = nullptr);
+	bool _validate_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op, DataType *r_ret_type = nullptr, int *r_ret_size = nullptr);
 
-	Vector<Scalar> _get_node_values(const BlockNode *p_block, Node *p_node);
-	bool _eval_operator(const BlockNode *p_block, OperatorNode *p_op);
+	Vector<Scalar> _get_node_values(const BlockNode *p_block, const FunctionInfo &p_function_info, Node *p_node);
+	bool _eval_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op);
 	Scalar _eval_unary_scalar(const Scalar &p_a, Operator p_op, DataType p_ret_type);
 	Scalar _eval_scalar(const Scalar &p_a, const Scalar &p_b, Operator p_op, DataType p_ret_type, bool &r_is_valid);
 	Vector<Scalar> _eval_unary_vector(const Vector<Scalar> &p_va, DataType p_ret_type, Operator p_op);
@@ -1212,6 +1214,7 @@ private:
 	String _get_shader_type_list(const HashSet<String> &p_shader_types) const;
 	String _get_qualifier_str(ArgumentQualifier p_qualifier) const;
 
+	bool _parse_numeric_constant_expression(const FunctionInfo &p_function_info, float &r_constant);
 	Error _parse_shader(const HashMap<StringName, FunctionInfo> &p_functions, const Vector<ModeInfo> &p_render_modes, const HashSet<String> &p_shader_types);
 
 	Error _find_last_flow_op_in_block(BlockNode *p_block, FlowOperation p_op);

+ 16 - 3
servers/rendering/shader_types.cpp

@@ -52,16 +52,29 @@ static ShaderLanguage::BuiltInInfo constt(ShaderLanguage::DataType p_type) {
 	return ShaderLanguage::BuiltInInfo(p_type, true);
 }
 
+static ShaderLanguage::BuiltInInfo constvt(ShaderLanguage::DataType p_type, const Vector<ShaderLanguage::Scalar> &p_values) {
+	return ShaderLanguage::BuiltInInfo(p_type, true, p_values);
+}
+
 ShaderTypes::ShaderTypes() {
 	singleton = this;
 
 	/*************** SPATIAL ***********************/
 
+	ShaderLanguage::Scalar pi_scalar;
+	pi_scalar.real = Math::PI;
+
+	ShaderLanguage::Scalar tau_scalar;
+	tau_scalar.real = Math::TAU;
+
+	ShaderLanguage::Scalar e_scalar;
+	e_scalar.real = Math::E;
+
 	shader_modes[RS::SHADER_SPATIAL].functions["global"].built_ins["TIME"] = constt(ShaderLanguage::TYPE_FLOAT);
 	shader_modes[RS::SHADER_SPATIAL].functions["global"].built_ins["EXPOSURE"] = constt(ShaderLanguage::TYPE_FLOAT);
-	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["PI"] = constt(ShaderLanguage::TYPE_FLOAT);
-	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["TAU"] = constt(ShaderLanguage::TYPE_FLOAT);
-	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["E"] = constt(ShaderLanguage::TYPE_FLOAT);
+	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["PI"] = constvt(ShaderLanguage::TYPE_FLOAT, { pi_scalar });
+	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["TAU"] = constvt(ShaderLanguage::TYPE_FLOAT, { tau_scalar });
+	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["E"] = constvt(ShaderLanguage::TYPE_FLOAT, { e_scalar });
 	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["OUTPUT_IS_SRGB"] = constt(ShaderLanguage::TYPE_BOOL);
 	shader_modes[RS::SHADER_SPATIAL].functions["constants"].built_ins["CLIP_SPACE_FAR"] = constt(ShaderLanguage::TYPE_FLOAT);