Browse Source

Finish validation rules;

bjorn 9 years ago
parent
commit
176c8b05fa
2 changed files with 143 additions and 18 deletions
  1. 123 0
      rules.lua
  2. 20 18
      validate.lua

+ 123 - 0
rules.lua

@@ -1,3 +1,4 @@
+local types = require 'types'
 local rules = {}
 local rules = {}
 
 
 local function checkValueMatchesType(valueNode, schemaType)
 local function checkValueMatchesType(valueNode, schemaType)
@@ -453,4 +454,126 @@ function rules.variablesAreDefined(node, context)
   end
   end
 end
 end
 
 
+function rules.variableUsageAllowed(node, context)
+  if context.currentOperation then
+    local variableMap = {}
+    for _, definition in ipairs(context.currentOperation.variableDefinitions or {}) do
+      variableMap[definition.variable.name.value] = definition
+    end
+
+    local arguments
+
+    if node.kind == 'field' then
+      arguments = { [node.name.value] = node.arguments }
+    elseif node.kind == 'fragmentSpread' then
+      local function collectArguments(referencedNode)
+        if referencedNode.kind == 'selectionSet' then
+          for _, selection in ipairs(referencedNode.selections) do
+            collectArguments(selection)
+          end
+        elseif referencedNode.kind == 'field' and referencedNode.arguments then
+          local fieldName = referencedNode.name.value
+          arguments[fieldName] = arguments[fieldName] or {}
+          for _, argument in ipairs(referencedNode.arguments) do
+            table.insert(arguments[fieldName], argument)
+          end
+        elseif referencedNode.kind == 'inlineFragment' then
+          return collectArguments(referencedNode.selectionSet)
+        elseif referencedNode.kind == 'fragmentSpread' then
+          local fragment = context.fragmentMap[referencedNode.name.value]
+          return fragment and collectArguments(fragment.selectionSet)
+        end
+      end
+
+      local fragment = context.fragmentMap[node.name.value]
+      if fragment then
+        arguments = {}
+        collectArguments(fragment.selectionSet)
+      end
+    end
+
+    if not arguments then return end
+
+    for field in pairs(arguments) do
+      local parentField = context.objects[#context.objects - 1].fields[field]
+      for i = 1, #arguments[field] do
+        local argument = arguments[field][i]
+        if argument.value.kind == 'variable' then
+          local argumentType = parentField.arguments[argument.name.value]
+
+          local variableName = argument.value.name.value
+          local variableDefinition = variableMap[variableName]
+          local hasDefault = variableDefinition.defaultValue ~= nil
+
+          local function typeFromAST(variable)
+            local innerType
+            if variable.kind == 'listType' then
+              innerType = typeFromAST(variable.type)
+              return innerType and types.list(innerType)
+            elseif variable.kind == 'nonNullType' then
+              innerType = typeFromAST(variable.type)
+              return innerType and types.nonNull(innerType)
+            else
+              assert(variable.kind == 'namedType', 'Variable must be a named type')
+              return context.schema:getType(variable.name.value)
+            end
+          end
+
+          local variableType = typeFromAST(variableDefinition.type)
+
+          if hasDefault and variableType.__type ~= 'NonNull' then
+            variableType = types.nonNull(variableType)
+          end
+
+          local function isTypeSubTypeOf(subType, superType)
+            if subType == superType then return true end
+
+            if superType.__type == 'NonNull' then
+              if subType.__type == 'NonNull' then
+                return isTypeSubTypeOf(subType.ofType, superType.ofType)
+              end
+
+              return false
+            elseif subType.__type == 'NonNull' then
+              return typeIsSubTypeOf(subType.ofType, superType)
+            end
+
+            if superType.__type == 'List' then
+              if subType.__type == 'List' then
+                return isTypeSubTypeOf(subType.ofType, superType.ofType)
+              end
+
+              return false
+            elseif subType.__type == 'List' then
+              return false
+            end
+
+            if subType.__type ~= 'Object' then return false end
+
+            if superType.__type == 'Interface' then
+              local implementors = context.schema:getImplementors(superType.name)
+              return implementors and implementors[context.schema:getType(subType.name)]
+            elseif superType.__type == 'Union' then
+              local types = superType.types
+              for i = 1, #types do
+                if types[i] == subType then
+                  return true
+                end
+              end
+
+              return false
+            end
+
+            return false
+          end
+
+          if not isTypeSubTypeOf(variableType, argumentType) then
+            error('Variable type mismatch')
+          end
+        end
+      end
+    end
+  end
+end
+
 return rules
 return rules

+ 20 - 18
validate.lua

@@ -99,7 +99,8 @@ local visitors = {
       rules.uniqueArgumentNames,
       rules.uniqueArgumentNames,
       rules.argumentsOfCorrectType,
       rules.argumentsOfCorrectType,
       rules.requiredArgumentsPresent,
       rules.requiredArgumentsPresent,
-      rules.directivesAreDefined
+      rules.directivesAreDefined,
+      rules.variableUsageAllowed
     }
     }
   },
   },
 
 
@@ -141,27 +142,27 @@ local visitors = {
       table.insert(context.objects, fragmentType)
       table.insert(context.objects, fragmentType)
 
 
       if context.currentOperation then
       if context.currentOperation then
-        local function collectTransitiveVariables(node)
-          if not node then return end
+        local function collectTransitiveVariables(referencedNode)
+          if not referencedNode then return end
 
 
-          if node.kind == 'selectionSet' then
-            for _, selection in ipairs(node.selections) do
+          if referencedNode.kind == 'selectionSet' then
+            for _, selection in ipairs(referencedNode.selections) do
               collectTransitiveVariables(selection)
               collectTransitiveVariables(selection)
             end
             end
-          elseif node.kind == 'field' and node.arguments then
-            for _, argument in ipairs(node.arguments) do
+          elseif referencedNode.kind == 'field' and referencedNode.arguments then
+            for _, argument in ipairs(referencedNode.arguments) do
               collectTransitiveVariables(argument)
               collectTransitiveVariables(argument)
             end
             end
-          elseif node.kind == 'argument' then
-            return collectTransitiveVariables(node.value)
-          elseif node.kind == 'listType' or node.kind == 'nonNullType' then
-            return collectTransitiveVariables(node.type)
-          elseif node.kind == 'variable' then
-            context.variableReferences[node.name.value] = node.name.value
-          elseif node.kind == 'inlineFragment' then
-            return collectTransitiveVariables(node.selectionSet)
-          elseif node.kind == 'fragmentSpread' then
-            local fragment = context.fragmentMap[node.name.value]
+          elseif referencedNode.kind == 'argument' then
+            return collectTransitiveVariables(referencedNode.value)
+          elseif referencedNode.kind == 'listType' or referencedNode.kind == 'nonNullType' then
+            return collectTransitiveVariables(referencedNode.type)
+          elseif referencedNode.kind == 'variable' then
+            context.variableReferences[referencedNode.name.value] = true
+          elseif referencedNode.kind == 'inlineFragment' then
+            return collectTransitiveVariables(referencedNode.selectionSet)
+          elseif referencedNode.kind == 'fragmentSpread' then
+            local fragment = context.fragmentMap[referencedNode.name.value]
             return fragment and collectTransitiveVariables(fragment.selectionSet)
             return fragment and collectTransitiveVariables(fragment.selectionSet)
           end
           end
         end
         end
@@ -173,7 +174,8 @@ local visitors = {
     rules = {
     rules = {
       rules.fragmentSpreadTargetDefined,
       rules.fragmentSpreadTargetDefined,
       rules.fragmentSpreadIsPossible,
       rules.fragmentSpreadIsPossible,
-      rules.directivesAreDefined
+      rules.directivesAreDefined,
+      rules.variableUsageAllowed
     }
     }
   },
   },