浏览代码

Fix accessing parent field in multiple places

Ruslan Talpa 9 年之前
父节点
当前提交
45d834bea1
共有 3 个文件被更改,包括 21 次插入17 次删除
  1. 10 10
      graphql/rules.lua
  2. 9 0
      graphql/util.lua
  3. 2 7
      graphql/validate.lua

+ 10 - 10
graphql/rules.lua

@@ -31,13 +31,16 @@ end
 function rules.fieldsDefinedOnType(node, context)
 function rules.fieldsDefinedOnType(node, context)
   if context.objects[#context.objects] == false then
   if context.objects[#context.objects] == false then
     local parent = context.objects[#context.objects - 1]
     local parent = context.objects[#context.objects - 1]
+    if(parent.__type == 'List') then
+      parent = parent.ofType
+    end
     error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
     error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
   end
   end
 end
 end
 
 
 function rules.argumentsDefinedOnType(node, context)
 function rules.argumentsDefinedOnType(node, context)
   if node.arguments then
   if node.arguments then
-    local parentField = context.objects[#context.objects - 1].fields[node.name.value]
+    local parentField = util.getParentField(context, node.name.value, 1)
     for _, argument in pairs(node.arguments) do
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local name = argument.name.value
       if not parentField.arguments[name] then
       if not parentField.arguments[name] then
@@ -175,7 +178,7 @@ end
 
 
 function rules.argumentsOfCorrectType(node, context)
 function rules.argumentsOfCorrectType(node, context)
   if node.arguments then
   if node.arguments then
-    local parentField = context.objects[#context.objects - 1].fields[node.name.value]
+    local parentField = util.getParentField(context, node.name.value, 1)
     for _, argument in pairs(node.arguments) do
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local name = argument.name.value
       local argumentType = parentField.arguments[name]
       local argumentType = parentField.arguments[name]
@@ -186,13 +189,7 @@ end
 
 
 function rules.requiredArgumentsPresent(node, context)
 function rules.requiredArgumentsPresent(node, context)
   local arguments = node.arguments or {}
   local arguments = node.arguments or {}
-  local parentField
-  if context.objects[#context.objects - 1].__type == 'List' then
-    parentField = context.objects[#context.objects - 2].fields[node.name.value]
-  else
-    parentField = context.objects[#context.objects - 1].fields[node.name.value]
-  end
-
+  local parentField = util.getParentField(context, node.name.value, 1)
   for name, argument in pairs(parentField.arguments) do
   for name, argument in pairs(parentField.arguments) do
     if argument.__type == 'NonNull' then
     if argument.__type == 'NonNull' then
       local present = util.find(arguments, function(argument)
       local present = util.find(arguments, function(argument)
@@ -279,6 +276,9 @@ end
 function rules.fragmentSpreadIsPossible(node, context)
 function rules.fragmentSpreadIsPossible(node, context)
   local fragment = node.kind == 'inlineFragment' and node or context.fragmentMap[node.name.value]
   local fragment = node.kind == 'inlineFragment' and node or context.fragmentMap[node.name.value]
   local parentType = context.objects[#context.objects - 1]
   local parentType = context.objects[#context.objects - 1]
+  if(parent.__type == 'List') then
+      parent = parent.ofType
+  end
 
 
   local fragmentType
   local fragmentType
   if node.kind == 'inlineFragment' then
   if node.kind == 'inlineFragment' then
@@ -451,7 +451,7 @@ function rules.variableUsageAllowed(node, context)
     if not arguments then return end
     if not arguments then return end
 
 
     for field in pairs(arguments) do
     for field in pairs(arguments) do
-      local parentField = context.objects[#context.objects - 1].fields[field]
+      local parentField = util.getParentField(context, field, 1)
       for i = 1, #arguments[field] do
       for i = 1, #arguments[field] do
         local argument = arguments[field][i]
         local argument = arguments[field][i]
         if argument.value.kind == 'variable' then
         if argument.value.kind == 'variable' then

+ 9 - 0
graphql/util.lua

@@ -23,6 +23,15 @@ function util.bind1(func, x)
   end
   end
 end
 end
 
 
+function util.getParentField(context, name, step_back)
+  local obj = context.objects[#context.objects - step_back]
+    if obj.__type == 'List' then
+      return obj.ofType.fields[name]
+    else
+      return obj.fields[name]
+    end
+end
+
 function util.coerceValue(node, schemaType, variables)
 function util.coerceValue(node, schemaType, variables)
   variables = variables or {}
   variables = variables or {}
 
 

+ 2 - 7
graphql/validate.lua

@@ -1,5 +1,6 @@
 local path = (...):gsub('%.[^%.]+$', '')
 local path = (...):gsub('%.[^%.]+$', '')
 local rules = require(path .. '.rules')
 local rules = require(path .. '.rules')
+local util = require(path .. '.util')
 
 
 local visitors = {
 local visitors = {
   document = {
   document = {
@@ -58,16 +59,10 @@ local visitors = {
 
 
   field = {
   field = {
     enter = function(node, context)
     enter = function(node, context)
-      local parentField
-      if context.objects[#context.objects].__type == 'List' then
-        parentField = context.objects[#context.objects - 1].fields[node.name.value]
-      else
-        parentField = context.objects[#context.objects].fields[node.name.value]
-      end
+      local parentField = util.getParentField(context, node.name.value, 0)
 
 
       -- false is a special value indicating that the field was not present in the type definition.
       -- false is a special value indicating that the field was not present in the type definition.
       local field = parentField and parentField.kind or false
       local field = parentField and parentField.kind or false
-
       table.insert(context.objects, field)
       table.insert(context.objects, field)
     end,
     end,