Sfoglia il codice sorgente

Execution basically works;

bjorn 9 anni fa
parent
commit
9b61599cd3
3 ha cambiato i file con 75 aggiunte e 32 eliminazioni
  1. 71 28
      execute.lua
  2. 1 1
      rules.lua
  3. 3 3
      util.lua

+ 71 - 28
execute.lua

@@ -34,7 +34,7 @@ local function shouldIncludeNode(selection, context)
 
 
       if not ifArgument then return end
       if not ifArgument then return end
 
 
-      return util.coerceValue(ifArgument.value, _type.arguments['if'])
+      return util.coerceValue(ifArgument.value, _type.arguments['if'], context.variables)
     end
     end
 
 
     if isDirectiveActive('skip', types.skip) then return false end
     if isDirectiveActive('skip', types.skip) then return false end
@@ -52,27 +52,28 @@ local function doesFragmentApply(fragment, type, context)
   if innerType == type then
   if innerType == type then
     return true
     return true
   elseif innerType.__type == 'Interface' then
   elseif innerType.__type == 'Interface' then
-    return schema:getImplementors(type)[innerType]
+    local implementors = context.schema:getImplementors(innerType.name)
+    return implementors and implementors[type]
   elseif innerType.__type == 'Union' then
   elseif innerType.__type == 'Union' then
-    return util.find(type.types, function(member)
-      return member == innerType
+    return util.find(innerType.types, function(member)
+      return member == type
     end)
     end)
   end
   end
 end
 end
 
 
 local function mergeSelectionSets(fields)
 local function mergeSelectionSets(fields)
-  local selectionSet = {}
+  local selections = {}
 
 
   for i = 1, #fields do
   for i = 1, #fields do
     local selectionSet = fields[i].selectionSet
     local selectionSet = fields[i].selectionSet
     if selectionSet then
     if selectionSet then
       for j = 1, #selectionSet.selections do
       for j = 1, #selectionSet.selections do
-        table.insert(selectionSet, selectionSet.selections[j])
+        table.insert(selections, selectionSet.selections[j])
       end
       end
     end
     end
   end
   end
 
 
-  return selectionSet
+  return selections
 end
 end
 
 
 local function defaultResolver(object, fields, info)
 local function defaultResolver(object, fields, info)
@@ -112,25 +113,25 @@ local function buildContext(schema, tree, variables, operationName)
   return context
   return context
 end
 end
 
 
-local function collectFields(objectType, selectionSet, visitedFragments, result, context)
-  for _, selection in ipairs(selectionSet.selections) do
+local function collectFields(objectType, selections, visitedFragments, result, context)
+  for _, selection in ipairs(selections) do
     if selection.kind == 'field' then
     if selection.kind == 'field' then
-      if shouldIncludeNode(selection) then
+      if shouldIncludeNode(selection, context) then
         local name = getFieldResponseKey(selection)
         local name = getFieldResponseKey(selection)
         result[name] = result[name] or {}
         result[name] = result[name] or {}
         table.insert(result[name], selection)
         table.insert(result[name], selection)
       end
       end
     elseif selection.kind == 'inlineFragment' then
     elseif selection.kind == 'inlineFragment' then
-      if shouldIncludeNode(selection) and doesFragmentApply(selection, objectType, context) then
-        collectFields(objectType, selection.selectionSet, visitedFragments, result, context)
+      if shouldIncludeNode(selection, context) and doesFragmentApply(selection, objectType, context) then
+        collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
       end
       end
     elseif selection.kind == 'fragmentSpread' then
     elseif selection.kind == 'fragmentSpread' then
       local fragmentName = selection.name.value
       local fragmentName = selection.name.value
-      if shouldIncludeNode(selection) and not visitedFragments[fragmentName] then
+      if shouldIncludeNode(selection, context) and not visitedFragments[fragmentName] then
         visitedFragments[fragmentName] = true
         visitedFragments[fragmentName] = true
         local fragment = context.fragmentMap[fragmentName]
         local fragment = context.fragmentMap[fragmentName]
-        if fragment and shouldIncludeNode(fragment) and doesFragmentApply(fragment, objectType, context) then
-          collectFields(objectType, fragment.selectionSet, visitedFragments, result, context)
+        if fragment and shouldIncludeNode(fragment, context) and doesFragmentApply(fragment, objectType, context) then
+          collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
         end
         end
       end
       end
     end
     end
@@ -139,11 +140,57 @@ local function collectFields(objectType, selectionSet, visitedFragments, result,
   return result
   return result
 end
 end
 
 
-local function completeValue(fieldType, result, subSelectionSet)
-  return result -- TODO
+local evaluateSelections
+
+local function completeValue(fieldType, result, subSelections, context)
+  local fieldTypeName = fieldType.__type
+
+  if fieldTypeName == 'NonNull' then
+    local innerType = fieldType.ofType
+    local completedResult = completeValue(innerType, result, context)
+
+    if not completedResult then
+      error('No value provided for non-null ' .. innerType.name)
+    end
+
+    return completedResult
+  end
+
+  if not result then
+    return nil
+  end
+
+  if fieldTypeName == 'List' then
+    local innerType = fieldType.ofType
+
+    if type(result) ~= 'table' then
+      error('Expected a table for ' .. innerType.name .. ' list')
+    end
+
+    local values = {}
+
+    for i, value in ipairs(values) do
+      values[i] = completeValue(innerType, value, context)
+    end
+
+    return values
+  end
+
+  if fieldTypeName == 'Scalar' or fieldTypeName == 'Enum' then
+    return fieldType.serialize(result)
+  end
+
+  if fieldTypeName == 'Object' then
+    return evaluateSelections(fieldType, result, subSelections, context)
+  elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then
+    local objectType = fieldType.resolveType(result)
+    return evaluateSelections(objectType, result, subSelections, context)
+  end
+
+  error('Unknown type "' .. fieldTypeName .. '" for field "' .. field.name .. '"')
 end
 end
 
 
-local function getFieldEntry(objectType, object, fields)
+local function getFieldEntry(objectType, object, fields, context)
   local firstField = fields[1]
   local firstField = fields[1]
   local responseKey = getFieldResponseKey(firstField)
   local responseKey = getFieldResponseKey(firstField)
   local fieldType = objectType.fields[firstField.name.value]
   local fieldType = objectType.fields[firstField.name.value]
@@ -155,20 +202,16 @@ local function getFieldEntry(objectType, object, fields)
   -- TODO correct arguments to resolve
   -- TODO correct arguments to resolve
   local resolvedObject = (fieldType.resolve or defaultResolver)(object, fields, {})
   local resolvedObject = (fieldType.resolve or defaultResolver)(object, fields, {})
 
 
-  if not resolvedObject then
-    return nil -- TODO null
-  end
-
-  local subSelectionSet = mergeSelectionSets(fields)
-  local responseValue = completeValue(fieldType, resolvedObject, subSelectionSet)
+  local subSelections = mergeSelectionSets(fields)
+  local responseValue = completeValue(fieldType.kind, resolvedObject, subSelections, context)
   return responseValue
   return responseValue
 end
 end
 
 
-local function evaluateSelectionSet(objectType, object, selectionSet, context)
-  local groupedFieldSet = collectFields(objectType, selectionSet, {}, {}, context)
+evaluateSelections = function(objectType, object, selections, context)
+  local groupedFieldSet = collectFields(objectType, selections, {}, {}, context)
 
 
   return util.map(groupedFieldSet, function(fields)
   return util.map(groupedFieldSet, function(fields)
-    return getFieldEntry(objectType, object, fields)
+    return getFieldEntry(objectType, object, fields, context)
   end)
   end)
 end
 end
 
 
@@ -180,5 +223,5 @@ return function(schema, tree, variables, operationName, rootValue)
     error('Unsupported operation "' .. context.operation.operation .. '"')
     error('Unsupported operation "' .. context.operation.operation .. '"')
   end
   end
 
 
-  return evaluateSelectionSet(rootType, rootValue, context.operation.selectionSet, context)
+  return evaluateSelections(rootType, rootValue, context.operation.selectionSet.selections, context)
 end
 end

+ 1 - 1
rules.lua

@@ -128,7 +128,7 @@ function rules.unambiguousSelections(node, context)
   local function validateSelectionSet(selectionSet, parentType)
   local function validateSelectionSet(selectionSet, parentType)
     for _, selection in ipairs(selectionSet.selections) do
     for _, selection in ipairs(selectionSet.selections) do
       if selection.kind == 'field' then
       if selection.kind == 'field' then
-        if not parentType or not parentType.fields[selection.name.value] then return end
+        if not parentType or not parentType.fields or not parentType.fields[selection.name.value] then return end
 
 
         local key = selection.alias and selection.alias.name.value or selection.name.value
         local key = selection.alias and selection.alias.name.value or selection.name.value
         local definition = parentType.fields[selection.name.value].kind
         local definition = parentType.fields[selection.name.value].kind

+ 3 - 3
util.lua

@@ -17,7 +17,7 @@ function util.coerceValue(node, schemaType, variables)
   variables = variables or {}
   variables = variables or {}
 
 
   if schemaType.__type == 'NonNull' then
   if schemaType.__type == 'NonNull' then
-    return util.coerceValue(node, schemaType.ofType)
+    return util.coerceValue(node, schemaType.ofType, variables)
   end
   end
 
 
   if not node then
   if not node then
@@ -34,7 +34,7 @@ function util.coerceValue(node, schemaType, variables)
     end
     end
 
 
     return util.map(node.values, function(value)
     return util.map(node.values, function(value)
-      return util.coerceValue(node.values[i], schemaType.ofType)
+      return util.coerceValue(node.values[i], schemaType.ofType, variables)
     end)
     end)
   end
   end
 
 
@@ -48,7 +48,7 @@ function util.coerceValue(node, schemaType, variables)
         error('Unknown input object field "' .. field.name .. '"')
         error('Unknown input object field "' .. field.name .. '"')
       end
       end
 
 
-      return util.coerceValue(schemaType.fields[field.name].kind, field.value)
+      return util.coerceValue(schemaType.fields[field.name].kind, field.value, variables)
     end)
     end)
   end
   end