Browse Source

Merge pull request #5 from ruslantalpa/fix_get_parent_field

Fix get parent field
Bjorn Swenson 9 years ago
parent
commit
5235df4b8b
3 changed files with 20 additions and 18 deletions
  1. 7 10
      graphql/rules.lua
  2. 11 1
      graphql/util.lua
  3. 2 7
      graphql/validate.lua

+ 7 - 10
graphql/rules.lua

@@ -31,13 +31,16 @@ end
 function rules.fieldsDefinedOnType(node, context)
 function rules.fieldsDefinedOnType(node, context)
   if context.objects[#context.objects] == false then
   if context.objects[#context.objects] == false then
     local parent = context.objects[#context.objects - 1]
     local parent = context.objects[#context.objects - 1]
+    if(parent.__type == 'List') then
+      parent = parent.ofType
+    end
     error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
     error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
   end
   end
 end
 end
 
 
 function rules.argumentsDefinedOnType(node, context)
 function rules.argumentsDefinedOnType(node, context)
   if node.arguments then
   if node.arguments then
-    local parentField = context.objects[#context.objects - 1].fields[node.name.value]
+    local parentField = util.getParentField(context, node.name.value)
     for _, argument in pairs(node.arguments) do
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local name = argument.name.value
       if not parentField.arguments[name] then
       if not parentField.arguments[name] then
@@ -175,7 +178,7 @@ end
 
 
 function rules.argumentsOfCorrectType(node, context)
 function rules.argumentsOfCorrectType(node, context)
   if node.arguments then
   if node.arguments then
-    local parentField = context.objects[#context.objects - 1].fields[node.name.value]
+    local parentField = util.getParentField(context, node.name.value)
     for _, argument in pairs(node.arguments) do
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local name = argument.name.value
       local argumentType = parentField.arguments[name]
       local argumentType = parentField.arguments[name]
@@ -186,13 +189,7 @@ end
 
 
 function rules.requiredArgumentsPresent(node, context)
 function rules.requiredArgumentsPresent(node, context)
   local arguments = node.arguments or {}
   local arguments = node.arguments or {}
-  local parentField
-  if context.objects[#context.objects - 1].__type == 'List' then
-    parentField = context.objects[#context.objects - 2].fields[node.name.value]
-  else
-    parentField = context.objects[#context.objects - 1].fields[node.name.value]
-  end
-
+  local parentField = util.getParentField(context, node.name.value)
   for name, argument in pairs(parentField.arguments) do
   for name, argument in pairs(parentField.arguments) do
     if argument.__type == 'NonNull' then
     if argument.__type == 'NonNull' then
       local present = util.find(arguments, function(argument)
       local present = util.find(arguments, function(argument)
@@ -451,7 +448,7 @@ function rules.variableUsageAllowed(node, context)
     if not arguments then return end
     if not arguments then return end
 
 
     for field in pairs(arguments) do
     for field in pairs(arguments) do
-      local parentField = context.objects[#context.objects - 1].fields[field]
+      local parentField = util.getParentField(context, field)
       for i = 1, #arguments[field] do
       for i = 1, #arguments[field] do
         local argument = arguments[field][i]
         local argument = arguments[field][i]
         if argument.value.kind == 'variable' then
         if argument.value.kind == 'variable' then

+ 11 - 1
graphql/util.lua

@@ -23,6 +23,16 @@ function util.bind1(func, x)
   end
   end
 end
 end
 
 
+function util.getParentField(context, name, count)
+  count = count == nil and 1 or count
+  local obj = context.objects[#context.objects - count]
+  if obj.__type == 'List' then
+    return obj.ofType.fields[name]
+  else
+    return obj.fields[name]
+  end
+end
+
 function util.coerceValue(node, schemaType, variables)
 function util.coerceValue(node, schemaType, variables)
   variables = variables or {}
   variables = variables or {}
 
 
@@ -58,7 +68,7 @@ function util.coerceValue(node, schemaType, variables)
         error('Unknown input object field "' .. field.name .. '"')
         error('Unknown input object field "' .. field.name .. '"')
       end
       end
 
 
-      return util.coerceValue(schemaType.fields[field.name].kind, field.value, variables)
+      return util.coerceValue(field.value, schemaType.fields[field.name].kind, variables)
     end)
     end)
   end
   end
 
 

+ 2 - 7
graphql/validate.lua

@@ -1,5 +1,6 @@
 local path = (...):gsub('%.[^%.]+$', '')
 local path = (...):gsub('%.[^%.]+$', '')
 local rules = require(path .. '.rules')
 local rules = require(path .. '.rules')
+local util = require(path .. '.util')
 
 
 local visitors = {
 local visitors = {
   document = {
   document = {
@@ -58,16 +59,10 @@ local visitors = {
 
 
   field = {
   field = {
     enter = function(node, context)
     enter = function(node, context)
-      local parentField
-      if context.objects[#context.objects].__type == 'List' then
-        parentField = context.objects[#context.objects - 1].fields[node.name.value]
-      else
-        parentField = context.objects[#context.objects].fields[node.name.value]
-      end
+      local parentField = util.getParentField(context, node.name.value, 0)
 
 
       -- false is a special value indicating that the field was not present in the type definition.
       -- false is a special value indicating that the field was not present in the type definition.
       local field = parentField and parentField.kind or false
       local field = parentField and parentField.kind or false
-
       table.insert(context.objects, field)
       table.insert(context.objects, field)
     end,
     end,