Explorar o código

Fix accessing parent field in multiple places

Ruslan Talpa %!s(int64=9) %!d(string=hai) anos
pai
achega
45d834bea1
Modificáronse 3 ficheiros con 21 adicións e 17 borrados
  1. 10 10
      graphql/rules.lua
  2. 9 0
      graphql/util.lua
  3. 2 7
      graphql/validate.lua

+ 10 - 10
graphql/rules.lua

@@ -31,13 +31,16 @@ end
 function rules.fieldsDefinedOnType(node, context)
   if context.objects[#context.objects] == false then
     local parent = context.objects[#context.objects - 1]
+    if(parent.__type == 'List') then
+      parent = parent.ofType
+    end
     error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
   end
 end
 
 function rules.argumentsDefinedOnType(node, context)
   if node.arguments then
-    local parentField = context.objects[#context.objects - 1].fields[node.name.value]
+    local parentField = util.getParentField(context, node.name.value, 1)
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       if not parentField.arguments[name] then
@@ -175,7 +178,7 @@ end
 
 function rules.argumentsOfCorrectType(node, context)
   if node.arguments then
-    local parentField = context.objects[#context.objects - 1].fields[node.name.value]
+    local parentField = util.getParentField(context, node.name.value, 1)
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local argumentType = parentField.arguments[name]
@@ -186,13 +189,7 @@ end
 
 function rules.requiredArgumentsPresent(node, context)
   local arguments = node.arguments or {}
-  local parentField
-  if context.objects[#context.objects - 1].__type == 'List' then
-    parentField = context.objects[#context.objects - 2].fields[node.name.value]
-  else
-    parentField = context.objects[#context.objects - 1].fields[node.name.value]
-  end
-
+  local parentField = util.getParentField(context, node.name.value, 1)
   for name, argument in pairs(parentField.arguments) do
     if argument.__type == 'NonNull' then
       local present = util.find(arguments, function(argument)
@@ -279,6 +276,9 @@ end
 function rules.fragmentSpreadIsPossible(node, context)
   local fragment = node.kind == 'inlineFragment' and node or context.fragmentMap[node.name.value]
   local parentType = context.objects[#context.objects - 1]
+  if(parent.__type == 'List') then
+      parent = parent.ofType
+  end
 
   local fragmentType
   if node.kind == 'inlineFragment' then
@@ -451,7 +451,7 @@ function rules.variableUsageAllowed(node, context)
     if not arguments then return end
 
     for field in pairs(arguments) do
-      local parentField = context.objects[#context.objects - 1].fields[field]
+      local parentField = util.getParentField(context, field, 1)
       for i = 1, #arguments[field] do
         local argument = arguments[field][i]
         if argument.value.kind == 'variable' then

+ 9 - 0
graphql/util.lua

@@ -23,6 +23,15 @@ function util.bind1(func, x)
   end
 end
 
+function util.getParentField(context, name, step_back)
+  local obj = context.objects[#context.objects - step_back]
+    if obj.__type == 'List' then
+      return obj.ofType.fields[name]
+    else
+      return obj.fields[name]
+    end
+end
+
 function util.coerceValue(node, schemaType, variables)
   variables = variables or {}
 

+ 2 - 7
graphql/validate.lua

@@ -1,5 +1,6 @@
 local path = (...):gsub('%.[^%.]+$', '')
 local rules = require(path .. '.rules')
+local util = require(path .. '.util')
 
 local visitors = {
   document = {
@@ -58,16 +59,10 @@ local visitors = {
 
   field = {
     enter = function(node, context)
-      local parentField
-      if context.objects[#context.objects].__type == 'List' then
-        parentField = context.objects[#context.objects - 1].fields[node.name.value]
-      else
-        parentField = context.objects[#context.objects].fields[node.name.value]
-      end
+      local parentField = util.getParentField(context, node.name.value, 0)
 
       -- false is a special value indicating that the field was not present in the type definition.
       local field = parentField and parentField.kind or false
-
       table.insert(context.objects, field)
     end,