execute.lua 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. local types = require 'types'
  2. local util = require 'util'
  --- Resolve a type reference from a parsed AST node against the schema.
  -- List and non-null wrappers are unwrapped recursively; the innermost
  -- named type is looked up with `schema:getType`.
  -- @param node AST node of kind 'listType', 'nonNullType' or 'namedType'
  -- @param schema schema object providing `getType(name)`
  -- @return the resolved (possibly wrapped) type, or nil when
  --   `schema:getType` returns nil for the innermost name
  local function typeFromAST(node, schema)
    local innerType
    if node.kind == 'listType' then
      -- The schema must be forwarded so the nested named type can be resolved.
      innerType = typeFromAST(node.type, schema)
      return innerType and types.list(innerType)
    elseif node.kind == 'nonNullType' then
      innerType = typeFromAST(node.type, schema)
      return innerType and types.nonNull(innerType)
    else
      assert(node.kind == 'namedType', 'Variable must be a named type')
      return schema:getType(node.name.value)
    end
  end
  --- Key under which a field's value appears in the response: the alias
  -- when the field has one, otherwise the field's own name.
  -- @param field field AST node
  -- @return response key string
  local function getFieldResponseKey(field)
    if field.alias then
      local aliased = field.alias.name.value
      if aliased then
        return aliased
      end
    end
    return field.name.value
  end
  --- Decide whether a selection survives its @skip / @include directives.
  -- @param selection AST node with an optional `directives` array
  -- @param context execution context; `context.variables` is handed to
  --   `util.coerceValue` when evaluating a directive's `if` argument
  -- @return false when @skip's `if` coerces truthy or @include's `if`
  --   coerces to exactly false; true otherwise
  local function shouldIncludeNode(selection, context)
    local directives = selection.directives
    if not directives then
      return true
    end

    -- Coerced value of the named directive's `if` argument, or nil when the
    -- directive or its `if` argument is absent.
    local function evaluateIf(directiveName, directiveType)
      local found = util.find(directives, function(candidate)
        return candidate.name.value == directiveName
      end)
      if not found then
        return nil
      end
      local condition = util.find(found.arguments, function(argument)
        return argument.name.value == 'if'
      end)
      if not condition then
        return nil
      end
      return util.coerceValue(condition.value, directiveType.arguments['if'], context.variables)
    end

    if evaluateIf('skip', types.skip) then
      return false
    end
    -- A missing @include (nil) still includes the node; only false excludes it.
    return evaluateIf('include', types.include) ~= false
  end
  --- Check whether a fragment's type condition matches the current object type.
  -- @param fragment inline fragment or fragment definition AST node
  -- @param type object type currently being evaluated
  -- @param context execution context (reads `context.schema`)
  -- @return truthy when the fragment has no type condition, the condition is
  --   `type` itself, an interface whose implementors include `type`, or a
  --   union whose member types include `type`; falsy otherwise
  local function doesFragmentApply(fragment, type, context)
    local condition = fragment.typeCondition
    if not condition then
      return true
    end

    local schema = context.schema
    local conditionType = typeFromAST(condition, schema)
    if conditionType == type then
      return true
    end

    local kind = conditionType.__type
    if kind == 'Interface' then
      local implementors = schema:getImplementors(conditionType.name)
      return implementors and implementors[type]
    end
    if kind == 'Union' then
      return util.find(conditionType.types, function(candidate)
        return candidate == type
      end)
    end
  end
  --- Concatenate, in order, the sub-selections of field nodes that share a
  -- response key. Fields without a selection set contribute nothing.
  -- @param fields array of field AST nodes
  -- @return new array of selection nodes
  local function mergeSelectionSets(fields)
    local merged = {}
    for fieldIndex = 1, #fields do
      local set = fields[fieldIndex].selectionSet
      if set then
        local items = set.selections
        for itemIndex = 1, #items do
          merged[#merged + 1] = items[itemIndex]
        end
      end
    end
    return merged
  end
  --- Fallback resolver: read the field's own (unaliased) name off the parent
  -- object. `info` is accepted for signature parity and not used.
  local function defaultResolver(object, fields, info)
    local fieldName = fields[1].name.value
    return object[fieldName]
  end
  --- Build the execution context for a parsed document: pick the operation
  -- to run and index fragment definitions by name.
  -- @param schema schema the document is executed against
  -- @param tree parsed document with a `definitions` array
  -- @param variables table of variable values (nil is treated as empty)
  -- @param operationName name of the operation to run; may be nil when the
  --   document contains exactly one operation
  -- @return context table with `schema`, `variables`, `operation`, `fragmentMap`
  -- @raise when no operation matches, or when `operationName` is nil and the
  --   document contains more than one operation
  local function buildContext(schema, tree, variables, operationName)
    local context = {
      schema = schema,
      -- Normalise nil so later lookups into the variables table are safe.
      variables = variables or {},
      fragmentMap = {},
    }

    for _, definition in ipairs(tree.definitions) do
      if definition.kind == 'operation' then
        if not operationName and context.operation then
          error('Operation name must be specified if more than one operation exists.')
        end
        -- Anonymous operations have no `name` node; they are only selected
        -- when no operation name was requested.
        local definitionName = definition.name and definition.name.value
        if not operationName or definitionName == operationName then
          context.operation = definition
        end
      elseif definition.kind == 'fragmentDefinition' then
        context.fragmentMap[definition.name.value] = definition
      end
    end

    if not context.operation then
      if operationName then
        error('Unknown operation "' .. operationName .. '"')
      else
        error('Must provide an operation')
      end
    end

    return context
  end
  --- Group the selections that apply to `objectType` by response key
  -- (the spec's CollectFields step). Matching inline fragments and named
  -- fragment spreads are flattened into the same grouping; each named
  -- fragment is expanded at most once, tracked in `visitedFragments`.
  -- NOTE(review): `result` is keyed by response key in a plain table, so the
  -- query's field order is not preserved; the spec requires ordered output.
  -- @param objectType object type currently being evaluated
  -- @param selections array of selection AST nodes
  -- @param visitedFragments set of already-expanded fragment names (mutated)
  -- @param result map of response key -> array of field nodes (mutated)
  -- @param context execution context (`fragmentMap` is read here)
  -- @return `result`
  local function collectFields(objectType, selections, visitedFragments, result, context)
    for _, node in ipairs(selections) do
      local kind = node.kind

      if kind == 'field' then
        if shouldIncludeNode(node, context) then
          local key = getFieldResponseKey(node)
          local group = result[key]
          if not group then
            group = {}
            result[key] = group
          end
          group[#group + 1] = node
        end

      elseif kind == 'inlineFragment' then
        local applies = shouldIncludeNode(node, context)
          and doesFragmentApply(node, objectType, context)
        if applies then
          collectFields(objectType, node.selectionSet.selections, visitedFragments, result, context)
        end

      elseif kind == 'fragmentSpread' then
        local fragmentName = node.name.value
        if shouldIncludeNode(node, context) and not visitedFragments[fragmentName] then
          -- Marked before expanding, so a spread that reaches itself again is skipped.
          visitedFragments[fragmentName] = true
          local fragment = context.fragmentMap[fragmentName]
          local applies = fragment
            and shouldIncludeNode(fragment, context)
            and doesFragmentApply(fragment, objectType, context)
          if applies then
            collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
          end
        end
      end
    end
    return result
  end
  -- Forward declaration: completeValue and evaluateSelections are mutually
  -- recursive; the function value is assigned after getFieldEntry.
  local evaluateSelections

  --- Complete a resolved value according to its declared type.
  -- Non-null wrappers enforce presence, lists are completed element by
  -- element, scalars and enums go through `serialize`, and objects /
  -- interfaces / unions recurse into `evaluateSelections`.
  -- @param fieldType declared type (table tagged by `__type`)
  -- @param result raw value returned by the field's resolver
  -- @param subSelections merged selection nodes used for object-like types
  -- @param context execution context
  -- @return the completed value, or nil when `result` is nil
  -- @raise when a non-null type gets nil, a list type gets a non-table, or
  --   the `__type` tag is not recognised
  local function completeValue(fieldType, result, subSelections, context)
    local fieldTypeName = fieldType.__type

    if fieldTypeName == 'NonNull' then
      local innerType = fieldType.ofType
      local completedResult = completeValue(innerType, result, subSelections, context)
      -- Compare against nil: `false` is a legitimate non-null Boolean value.
      if completedResult == nil then
        error('No value provided for non-null ' .. tostring(innerType.name or innerType.__type))
      end
      return completedResult
    end

    -- Only nil means "no value"; `false` must reach serialization.
    if result == nil then
      return nil
    end

    if fieldTypeName == 'List' then
      local innerType = fieldType.ofType
      if type(result) ~= 'table' then
        -- Wrapped inner types (e.g. [Int!]) have no `name`; fall back to the tag.
        error('Expected a table for ' .. tostring(innerType.name or innerType.__type) .. ' list')
      end
      local values = {}
      -- Iterate the resolved list, completing every element with the same
      -- sub-selections.
      for i, value in ipairs(result) do
        values[i] = completeValue(innerType, value, subSelections, context)
      end
      return values
    end

    if fieldTypeName == 'Scalar' or fieldTypeName == 'Enum' then
      return fieldType.serialize(result)
    end

    if fieldTypeName == 'Object' then
      return evaluateSelections(fieldType, result, subSelections, context)
    elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then
      local objectType = fieldType.resolveType(result)
      return evaluateSelections(objectType, result, subSelections, context)
    end

    error('Unknown type "' .. tostring(fieldTypeName) .. '" for type "' .. tostring(fieldType.name) .. '"')
  end
  --- Resolve and complete one response entry (the spec's ExecuteField step).
  -- @param objectType object type that owns the field (`fields` is read)
  -- @param object parent value passed to the resolver
  -- @param fields array of field nodes sharing one response key; the first
  --   node's name selects the field definition
  -- @param context execution context
  -- @return the completed value, or nil when `objectType` does not define
  --   the field
  local function getFieldEntry(objectType, object, fields, context)
    local fieldName = fields[1].name.value
    -- A field definition: its `kind` is the declared type handed to
    -- completeValue, and `resolve` is an optional custom resolver.
    local fieldDefinition = objectType.fields[fieldName]
    if fieldDefinition == nil then
      return nil
    end

    -- TODO correct arguments to resolve
    local resolve = fieldDefinition.resolve or defaultResolver
    local resolvedObject = resolve(object, fields, {})
    local subSelections = mergeSelectionSets(fields)
    -- Assigned to a local so exactly one value is returned.
    local responseValue = completeValue(fieldDefinition.kind, resolvedObject, subSelections, context)
    return responseValue
  end
  --- Evaluate a selection set against an object (the spec's
  -- ExecuteSelectionSet step). Fields are grouped by response key with
  -- collectFields, and each group is resolved via getFieldEntry.
  -- The shape of the returned table depends on `util.map` — presumably one
  -- entry per response key; confirm against util.
  evaluateSelections = function(objectType, object, selections, context)
    local fieldsByKey = collectFields(objectType, selections, {}, {}, context)
    local function executeGroup(fields)
      return getFieldEntry(objectType, object, fields, context)
    end
    return util.map(fieldsByKey, executeGroup)
  end
  --- Module entry point: execute a parsed document against a schema.
  -- @param schema schema indexed by the operation's type to find the root
  --   type (NOTE(review): presumably keys like 'query'/'mutation' — confirm)
  -- @param tree parsed document
  -- @param variables variable values for the operation
  -- @param operationName operation to run (optional with a single operation)
  -- @param rootValue object handed to the root field resolvers
  -- @return the result of evaluating the operation's selection set
  -- @raise when no operation can be selected or its root type is missing
  return function(schema, tree, variables, operationName, rootValue)
    local context = buildContext(schema, tree, variables, operationName)
    local operation = context.operation
    local operationType = operation.operation
    local rootType = schema[operationType]
    if not rootType then
      error('Unsupported operation "' .. operationType .. '"')
    end
    return evaluateSelections(rootType, rootValue, operation.selectionSet.selections, context)
  end