|
@@ -1,3 +1,4 @@
|
|
|
|
+local types = require 'types'
|
|
local rules = {}
|
|
local rules = {}
|
|
|
|
|
|
local function checkValueMatchesType(valueNode, schemaType)
|
|
local function checkValueMatchesType(valueNode, schemaType)
|
|
@@ -453,4 +454,126 @@ function rules.variablesAreDefined(node, context)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
|
|
|
|
|
|
+function rules.variableUsageAllowed(node, context)
|
|
|
|
+ if context.currentOperation then
|
|
|
|
+ local variableMap = {}
|
|
|
|
+ for _, definition in ipairs(context.currentOperation.variableDefinitions or {}) do
|
|
|
|
+ variableMap[definition.variable.name.value] = definition
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ local arguments
|
|
|
|
+
|
|
|
|
+ if node.kind == 'field' then
|
|
|
|
+ arguments = { [node.name.value] = node.arguments }
|
|
|
|
+ elseif node.kind == 'fragmentSpread' then
|
|
|
|
+ local function collectArguments(referencedNode)
|
|
|
|
+ if referencedNode.kind == 'selectionSet' then
|
|
|
|
+ for _, selection in ipairs(referencedNode.selections) do
|
|
|
|
+ collectArguments(selection)
|
|
|
|
+ end
|
|
|
|
+ elseif referencedNode.kind == 'field' and referencedNode.arguments then
|
|
|
|
+ local fieldName = referencedNode.name.value
|
|
|
|
+ arguments[fieldName] = arguments[fieldName] or {}
|
|
|
|
+ for _, argument in ipairs(referencedNode.arguments) do
|
|
|
|
+ table.insert(arguments[fieldName], argument)
|
|
|
|
+ end
|
|
|
|
+ elseif referencedNode.kind == 'inlineFragment' then
|
|
|
|
+ return collectArguments(referencedNode.selectionSet)
|
|
|
|
+ elseif referencedNode.kind == 'fragmentSpread' then
|
|
|
|
+ local fragment = context.fragmentMap[referencedNode.name.value]
|
|
|
|
+ return fragment and collectArguments(fragment.selectionSet)
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ local fragment = context.fragmentMap[node.name.value]
|
|
|
|
+ if fragment then
|
|
|
|
+ arguments = {}
|
|
|
|
+ collectArguments(fragment.selectionSet)
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ if not arguments then return end
|
|
|
|
+
|
|
|
|
+ for field in pairs(arguments) do
|
|
|
|
+ local parentField = context.objects[#context.objects - 1].fields[field]
|
|
|
|
+ for i = 1, #arguments[field] do
|
|
|
|
+ local argument = arguments[field][i]
|
|
|
|
+ if argument.value.kind == 'variable' then
|
|
|
|
+ local argumentType = parentField.arguments[argument.name.value]
|
|
|
|
+
|
|
|
|
+ local variableName = argument.value.name.value
|
|
|
|
+ local variableDefinition = variableMap[variableName]
|
|
|
|
+ local hasDefault = variableDefinition.defaultValue ~= nil
|
|
|
|
+
|
|
|
|
+ local function typeFromAST(variable)
|
|
|
|
+ local innerType
|
|
|
|
+ if variable.kind == 'listType' then
|
|
|
|
+ innerType = typeFromAST(variable.type)
|
|
|
|
+ return innerType and types.list(innerType)
|
|
|
|
+ elseif variable.kind == 'nonNullType' then
|
|
|
|
+ innerType = typeFromAST(variable.type)
|
|
|
|
+ return innerType and types.nonNull(innerType)
|
|
|
|
+ else
|
|
|
|
+ assert(variable.kind == 'namedType', 'Variable must be a named type')
|
|
|
|
+ return context.schema:getType(variable.name.value)
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ local variableType = typeFromAST(variableDefinition.type)
|
|
|
|
+
|
|
|
|
+ if hasDefault and variableType.__type ~= 'NonNull' then
|
|
|
|
+ variableType = types.nonNull(variableType)
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ local function isTypeSubTypeOf(subType, superType)
|
|
|
|
+ if subType == superType then return true end
|
|
|
|
+
|
|
|
|
+ if superType.__type == 'NonNull' then
|
|
|
|
+ if subType.__type == 'NonNull' then
|
|
|
|
+ return isTypeSubTypeOf(subType.ofType, superType.ofType)
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ return false
|
|
|
|
+ elseif subType.__type == 'NonNull' then
|
|
|
|
+ return typeIsSubTypeOf(subType.ofType, superType)
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ if superType.__type == 'List' then
|
|
|
|
+ if subType.__type == 'List' then
|
|
|
|
+ return isTypeSubTypeOf(subType.ofType, superType.ofType)
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ return false
|
|
|
|
+ elseif subType.__type == 'List' then
|
|
|
|
+ return false
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ if subType.__type ~= 'Object' then return false end
|
|
|
|
+
|
|
|
|
+ if superType.__type == 'Interface' then
|
|
|
|
+ local implementors = context.schema:getImplementors(superType.name)
|
|
|
|
+ return implementors and implementors[context.schema:getType(subType.name)]
|
|
|
|
+ elseif superType.__type == 'Union' then
|
|
|
|
+ local types = superType.types
|
|
|
|
+ for i = 1, #types do
|
|
|
|
+ if types[i] == subType then
|
|
|
|
+ return true
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ return false
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ return false
|
|
|
|
+ end
|
|
|
|
+
|
|
|
|
+ if not isTypeSubTypeOf(variableType, argumentType) then
|
|
|
|
+ error('Variable type mismatch')
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+end
|
|
|
|
+
|
|
return rules
|
|
return rules
|