execute.lua 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. local types = require 'types'
  2. local util = require 'util'
  --- Resolve a type AST node (namedType / listType / nonNullType) to a schema type.
  -- List and non-null wrappers are resolved recursively and re-wrapped with
  -- types.list / types.nonNull.
  -- @param node type AST node
  -- @param schema schema whose getType method resolves named types
  -- @return the resolved type; a falsy result from schema:getType propagates
  --   unchanged through any wrappers
  local function typeFromAST(node, schema)
    if node.kind == 'listType' then
      -- The schema must be forwarded: the innermost namedType needs it for lookup.
      local innerType = typeFromAST(node.type, schema)
      return innerType and types.list(innerType)
    elseif node.kind == 'nonNullType' then
      local innerType = typeFromAST(node.type, schema)
      return innerType and types.nonNull(innerType)
    end
    assert(node.kind == 'namedType', 'Variable must be a named type')
    return schema:getType(node.name.value)
  end
  --- Key under which a field's value appears in the response.
  -- The alias name wins when present; otherwise the field's own name is used.
  -- @param field field AST node
  -- @return response key string
  local function getFieldResponseKey(field)
    local aliasName = field.alias and field.alias.name.value
    if aliasName then
      return aliasName
    end
    return field.name.value
  end
  --- Decide whether a selection survives its @skip / @include directives.
  -- A selection is dropped when @skip(if:) evaluates truthy, or when
  -- @include(if:) evaluates to exactly false. It is kept in every other case,
  -- including when it has no directives at all.
  -- @param selection AST node that may carry a `directives` array
  -- @param context execution context; `variables` is used to coerce `if` values
  -- @return boolean
  local function shouldIncludeNode(selection, context)
    local directives = selection.directives
    if not directives then
      return true
    end

    -- Coerced value of the named directive's `if` argument, or nil when the
    -- directive or its `if` argument is absent.
    local function evaluateDirective(directiveName, directiveType)
      local match = util.find(directives, function(candidate)
        return candidate.name.value == directiveName
      end)
      if not match then
        return nil
      end
      local condition = util.find(match.arguments, function(arg)
        return arg.name.value == 'if'
      end)
      if not condition then
        return nil
      end
      return util.coerceValue(condition.value, directiveType.arguments['if'], context.variables)
    end

    if evaluateDirective('skip', types.skip) then
      return false
    end
    return evaluateDirective('include', types.include) ~= false
  end
  --- Check whether a fragment's type condition matches a concrete object type.
  -- Fragments without a type condition always apply. Otherwise the condition
  -- applies when it equals the object type, names an interface whose
  -- implementors include it, or names a union containing it.
  -- @param fragment inline fragment or fragment definition AST node
  -- @param objectType the runtime object type
  -- @param context execution context; `schema` resolves the condition type
  -- @return truthy when the fragment applies, falsy/nothing otherwise
  local function doesFragmentApply(fragment, objectType, context)
    local condition = fragment.typeCondition
    if not condition then
      return true
    end

    local conditionType = typeFromAST(condition, context.schema)
    if conditionType == objectType then
      return true
    end

    local conditionKind = conditionType.__type
    if conditionKind == 'Interface' then
      local implementors = context.schema:getImplementors(conditionType.name)
      return implementors and implementors[objectType]
    end
    if conditionKind == 'Union' then
      return util.find(conditionType.types, function(candidate)
        return candidate == objectType
      end)
    end
  end
  --- Concatenate the sub-selections of several field nodes into one array.
  -- Fields without a selectionSet contribute nothing.
  -- @param fields array of field AST nodes
  -- @return new array of selection nodes, in field order
  local function mergeSelectionSets(fields)
    local merged = {}
    for fieldIndex = 1, #fields do
      local selectionSet = fields[fieldIndex].selectionSet
      if selectionSet then
        local list = selectionSet.selections
        for k = 1, #list do
          merged[#merged + 1] = list[k]
        end
      end
    end
    return merged
  end
  --- Fallback resolver used when a field defines no `resolve` function.
  -- Reads the parent object's property named after the field itself (the
  -- field name, not its alias). `arguments` is ignored.
  -- @param object parent value
  -- @param arguments coerced field arguments (unused)
  -- @param info resolve info; the first entry of `fieldASTs` supplies the name
  -- @return object[fieldName]
  local function defaultResolver(object, arguments, info)
    local firstAST = info.fieldASTs[1]
    local propertyName = firstAST.name.value
    return object[propertyName]
  end
  --- Build the execution context for a parsed document.
  -- Selects the operation to execute and indexes fragment definitions by name.
  -- @param schema schema to execute against
  -- @param tree parsed document; its `definitions` array is scanned
  -- @param rootValue value handed to root-level resolvers
  -- @param variables variable values supplied with the request
  -- @param operationName optional name of the operation to run; required when
  --   the document contains more than one operation
  -- @return context table with schema, rootValue, variables, operation, fragmentMap
  -- @raise when no operation name is given but several operations exist, when
  --   the named operation is not found, or when the document has no operation
  local function buildContext(schema, tree, rootValue, variables, operationName)
    local context = {
      schema = schema,
      rootValue = rootValue,
      variables = variables,
      fragmentMap = {}
    }

    for _, definition in ipairs(tree.definitions) do
      if definition.kind == 'operation' then
        if not operationName and context.operation then
          error('Operation name must be specified if more than one operation exists.')
        end
        -- Anonymous operations have no `name` node; guard before comparing so
        -- selecting a named operation does not crash on them.
        local definitionName = definition.name and definition.name.value
        if not operationName or definitionName == operationName then
          context.operation = definition
        end
      elseif definition.kind == 'fragmentDefinition' then
        context.fragmentMap[definition.name.value] = definition
      end
    end

    if not context.operation then
      if operationName then
        error('Unknown operation "' .. operationName .. '"')
      else
        error('Must provide an operation')
      end
    end

    return context
  end
  --- Group the fields of a selection set by response key.
  -- Walks `selections`, descending into inline fragments and named fragment
  -- spreads whose directives allow inclusion and whose type condition applies
  -- to `objectType`. A named fragment is expanded at most once, tracked in
  -- `visitedFragments`.
  -- @param objectType object type the selections are evaluated against
  -- @param selections array of selection AST nodes
  -- @param visitedFragments set of fragment names already expanded (mutated)
  -- @param result accumulator: response key -> array of field nodes (mutated)
  -- @param context execution context; `fragmentMap` is read here
  -- @return `result`
  local function collectFields(objectType, selections, visitedFragments, result, context)
    for _, node in ipairs(selections) do
      local kind = node.kind

      if kind == 'field' then
        if shouldIncludeNode(node, context) then
          local responseKey = getFieldResponseKey(node)
          local group = result[responseKey]
          if not group then
            group = {}
            result[responseKey] = group
          end
          group[#group + 1] = node
        end

      elseif kind == 'inlineFragment' then
        local applies = shouldIncludeNode(node, context)
          and doesFragmentApply(node, objectType, context)
        if applies then
          collectFields(objectType, node.selectionSet.selections, visitedFragments, result, context)
        end

      elseif kind == 'fragmentSpread' then
        local spreadName = node.name.value
        if shouldIncludeNode(node, context) and not visitedFragments[spreadName] then
          -- Marked before lookup, so a fragment is never expanded twice.
          visitedFragments[spreadName] = true
          local definition = context.fragmentMap[spreadName]
          local applies = definition
            and shouldIncludeNode(definition, context)
            and doesFragmentApply(definition, objectType, context)
          if applies then
            collectFields(objectType, definition.selectionSet.selections, visitedFragments, result, context)
          end
        end
      end
    end
    return result
  end
  -- Forward declaration: evaluateSelections and completeValue are mutually recursive.
  local evaluateSelections

  --- Complete a resolved field value according to its declared type.
  -- NonNull wrappers are unwrapped and reject a nil result. Lists are completed
  -- element by element with the item type. Scalars/enums are passed through
  -- their `serialize` function. Objects, interfaces and unions recurse into
  -- their sub-selections via evaluateSelections.
  -- @param fieldType type object (dispatch on `__type`; wrappers expose `ofType`)
  -- @param result value returned by the field's resolver
  -- @param subSelections merged sub-selection AST nodes for composite types
  -- @param context execution context
  -- @return completed value; nil when `result` is nil and the type is nullable
  -- @raise on a nil value for a non-null type, a non-table value for a list
  --   type, or an unrecognised type kind
  local function completeValue(fieldType, result, subSelections, context)
    local fieldTypeName = fieldType.__type

    if fieldTypeName == 'NonNull' then
      local innerType = fieldType.ofType
      -- subSelections must be forwarded so wrapped lists/objects still complete.
      local completedResult = completeValue(innerType, result, subSelections, context)
      if completedResult == nil then
        error('No value provided for non-null ' .. tostring(innerType.name or innerType.__type))
      end
      return completedResult
    end

    if result == nil then
      return nil
    end

    if fieldTypeName == 'List' then
      local innerType = fieldType.ofType
      if type(result) ~= 'table' then
        error('Expected a table for ' .. tostring(innerType.name or innerType.__type) .. ' list')
      end
      local values = {}
      -- Iterate the resolved list itself, not the freshly created `values`.
      for i, value in ipairs(result) do
        values[i] = completeValue(innerType, value, subSelections, context)
      end
      return values
    end

    if fieldTypeName == 'Scalar' or fieldTypeName == 'Enum' then
      return fieldType.serialize(result)
    end

    if fieldTypeName == 'Object' then
      return evaluateSelections(fieldType, result, subSelections, context)
    elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then
      -- Abstract types choose their concrete object type from the runtime value.
      local objectType = fieldType.resolveType(result)
      return evaluateSelections(objectType, result, subSelections, context)
    end

    error('Unknown type "' .. tostring(fieldTypeName) .. '"')
  end
  --- Resolve and complete the value for one response key.
  -- All nodes in `fields` share the same response key; the first node supplies
  -- the field name and the argument AST.
  -- @param objectType parent type; `objectType.fields[name]` is the field
  --   definition, whose `kind` holds the field's return type
  -- @param object parent value passed to the resolver
  -- @param fields array of field AST nodes grouped under one response key
  -- @param context execution context
  -- @return the completed value, or nil when objectType has no such field
  local function getFieldEntry(objectType, object, fields, context)
    local firstField = fields[1]
    local fieldName = firstField.name.value
    local fieldType = objectType.fields[fieldName]
    if fieldType == nil then
      return nil
    end

    local argumentMap = {}
    for _, argument in ipairs(firstField.arguments or {}) do
      argumentMap[argument.name.value] = argument
    end

    local arguments = util.map(fieldType.arguments, function(argument, name)
      local supplied = argumentMap[name] and argumentMap[name].value
      if supplied then
        -- Explicit nil check instead of `a and b or c`: a coerced `false`
        -- must be kept rather than replaced by the default value.
        local coerced = util.coerceValue(supplied, argument, context.variables)
        if coerced ~= nil then
          return coerced
        end
      end
      return argument.defaultValue
    end)

    local info = {
      fieldName = fieldName,
      fieldASTs = fields,
      returnType = fieldType.kind,
      parentType = objectType,
      schema = context.schema,
      fragments = context.fragmentMap,
      rootValue = context.rootValue,
      operation = context.operation,
      variableValues = context.variables
    }

    local resolver = fieldType.resolve or defaultResolver
    local resolvedObject = resolver(object, arguments, info)
    local subSelections = mergeSelectionSets(fields)
    return completeValue(fieldType.kind, resolvedObject, subSelections, context)
  end
  --- Evaluate a selection set against `object`.
  -- Groups the selections by response key with collectFields, then resolves
  -- each group with getFieldEntry. The result is whatever util.map builds;
  -- presumably a table keyed by response key (assumes util.map preserves
  -- keys; confirm against util).
  -- @param objectType object type the selections apply to
  -- @param object parent value
  -- @param selections array of selection AST nodes
  -- @param context execution context
  evaluateSelections = function(objectType, object, selections, context)
    local fieldsByKey = collectFields(objectType, selections, {}, {}, context)
    local function resolveGroup(fieldNodes)
      return getFieldEntry(objectType, object, fieldNodes, context)
    end
    return util.map(fieldsByKey, resolveGroup)
  end
  --- Module entry point: execute a parsed document against a schema.
  -- The root type is looked up on the schema under the selected operation's
  -- `operation` field; an error is raised when the schema has no such entry.
  -- @param schema schema to execute against
  -- @param tree parsed document
  -- @param rootValue value handed to root-level resolvers
  -- @param variables variable values supplied with the request
  -- @param operationName optional name of the operation to run
  -- @return the evaluated root selection set
  return function(schema, tree, rootValue, variables, operationName)
    local context = buildContext(schema, tree, rootValue, variables, operationName)
    local operation = context.operation
    local operationType = operation.operation
    local rootType = schema[operationType]
      or error('Unsupported operation "' .. operationType .. '"')
    return evaluateSelections(rootType, rootValue, operation.selectionSet.selections, context)
  end