2
0
bjorn 9 жил өмнө
parent
commit
6154449287
1 өөрчлөгдсөн 93 нэмэгдсэн , 20 устгасан
  1. 93 20
      execute.lua

+ 93 - 20
execute.lua

@@ -1,5 +1,23 @@
-local function getArgumentValue(value, context)
-  return nil -- TODO
+local types = require 'types'
+local util = require 'util'
+
+local function typeFromAST(node, schema)
+  local innerType
+  if node.kind == 'listType' then
+    innerType = typeFromAST(node.type)
+    return innerType and types.list(innerType)
+  elseif node.kind == 'nonNullType' then
+    innerType = typeFromAST(node.type)
+    return innerType and types.nonNull(innerType)
+  else
+    assert(node.kind == 'namedType', 'Variable must be a named type')
+    return schema:getType(node.name.value)
+  end
+end
+
+local function defaultResolver(source, arguments, info)
+  local property = source[info.fieldName]
+  return type(property) == 'function' and property(source) or property
 end
 
 local function getFieldEntryKey(selection)
@@ -8,26 +26,45 @@ end
 
 local function shouldIncludeNode(selection, context)
   if selection.directives then
-    for _, directive in ipairs(selection.directives) do
-      if directive.name.value == 'skip' then
-        for _, argument in ipairs(directive.arguments) do
-          if argument.name == 'if' and getArgumentValue(argument.value, context) then
-            return false
-          end
-        end
-      elseif directive.name.value == 'include' then
-        for _, argument in ipairs(directive.arguments) do
-          if argument.name == 'if' and not getArgumentValue(argument.value, context) then
-            return false
-          end
-        end
-      end
+    local function isDirectiveActive(key, _type)
+      local directive = util.find(selection.directives, function(directive)
+        return directive.name.value == key
+      end)
+
+      if not directive then return end
+
+      local ifArgument = util.find(directive.arguments, function(argument)
+        return argument.name.value == 'if'
+      end)
+
+      if not ifArgument then return end
+
+      return util.coerceValue(ifArgument.value, _type.arguments['if'])
     end
+
+    if isDirectiveActive('skip', types.skip) then return false end
+    if isDirectiveActive('include', types.include) == false then return false end
   end
 
   return true
 end
 
+local function doesFragmentApply(fragment, type, context)
+  if not fragment.typeCondition then return true end
+
+  local innerType = typeFromAST(fragment.typeCondition, context.schema)
+
+  if innerType == type then
+    return true
+  elseif innerType.__type == 'Interface' then
+    return schema:getImplementors(type)[innerType]
+  elseif innerType.__type == 'Union' then
+    return util.find(type.types, function(member)
+      return member == innerType
+    end)
+  end
+end
+
 local function collectFields(selectionSet, type, fields, visitedFragments, context)
   for _, selection in ipairs(selectionSet.selections) do
     if selection.kind == 'field' then
@@ -37,7 +74,7 @@ local function collectFields(selectionSet, type, fields, visitedFragments, conte
         table.insert(fields[name], selection)
       end
     elseif selection.kind == 'inlineFragment' then
-      if shouldIncludeNode(selection) and doesFragmentApply(selection, type) then
+      if shouldIncludeNode(selection) and doesFragmentApply(selection, type, context) then
         collectFields(selection.selectionSet, type, fields, visitedFragments, context)
       end
     elseif selection.kind == 'fragmentSpread' then
@@ -45,7 +82,7 @@ local function collectFields(selectionSet, type, fields, visitedFragments, conte
       if shouldIncludeNode(selection) and not visitedFragments[fragmentName] then
         visitedFragments[fragmentName] = true
         local fragment = context.fragmentMap[fragmentName]
-        if fragment and shouldIncludeNode(fragment) and doesFragmentApply(fragment, type) then
+        if fragment and shouldIncludeNode(fragment) and doesFragmentApply(fragment, type, context) then
           collectFields(fragment.selectionSet, type, fields, visitedFragments, context)
         end
       end
@@ -58,6 +95,7 @@ end
 local function buildContext(schema, tree, variables, operationName)
   local context = {
     schema = schema,
+    variables = variables,
     operation = nil,
     fragmentMap = {}
   }
@@ -87,7 +125,41 @@ local function buildContext(schema, tree, variables, operationName)
   return context
 end
 
-return function(schema, tree, variables, operationName)
+local function executeFields(parentType, rootValue, fieldGroups, context)
+  local result = {}
+
+  for name, fieldGroup in pairs(fieldGroups) do
+    result[name] = resolveField(parentType, rootValue, fieldGroup, context)
+  end
+
+  return result
+end
+
+local function resolveField(parentType, rootValue, fields, context)
+  local field = fields[1]
+  local fieldName = field.name.value
+
+  local fieldType = parentType.fields[fieldName]
+  local returnType = fieldType.kind
+
+  local info = {
+    fieldName = fieldName,
+    fields = fields,
+    returnType = returnType,
+    parentType = parentType,
+    schema = context.schema,
+    fragments = context.fragmentMap,
+    rootValue = rootValue,
+    operation = context.operation,
+    variables = context.variables
+  }
+
+  local resolve = fieldType.resolve or defaultResolver
+
+  local result = resolve(source, {}, info)
+end
+
+return function(schema, tree, rootValue, variables, operationName)
   local context = buildContext(schema, tree, variables, operationName)
   local rootType = schema[context.operation.operation]
 
@@ -95,5 +167,6 @@ return function(schema, tree, variables, operationName)
     error('Unsupported operation "' .. context.operation.operation .. '"')
   end
 
-  local fields = collectFields(context.operation.selectionSet, rootType, {}, {}, context)
+  local fieldGroups = collectFields(context.operation.selectionSet, rootType, {}, {}, context)
+  return executeFields(rootType, rootValue, fieldGroups, context)
 end