2
0

execute.lua 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. local path = (...):gsub('%.[^%.]+$', '')
  2. local types = require(path .. '.types')
  3. local util = require(path .. '.util')
  4. local introspection = require(path .. '.introspection')
  5. local function typeFromAST(node, schema)
  6. local innerType
  7. if node.kind == 'listType' then
  8. innerType = typeFromAST(node.type)
  9. return innerType and types.list(innerType)
  10. elseif node.kind == 'nonNullType' then
  11. innerType = typeFromAST(node.type)
  12. return innerType and types.nonNull(innerType)
  13. else
  14. assert(node.kind == 'namedType', 'Variable must be a named type')
  15. return schema:getType(node.name.value)
  16. end
  17. end
  18. local function getFieldResponseKey(field)
  19. return field.alias and field.alias.name.value or field.name.value
  20. end
  21. local function shouldIncludeNode(selection, context)
  22. if selection.directives then
  23. local function isDirectiveActive(key, _type)
  24. local directive = util.find(selection.directives, function(directive)
  25. return directive.name.value == key
  26. end)
  27. if not directive then return end
  28. local ifArgument = util.find(directive.arguments, function(argument)
  29. return argument.name.value == 'if'
  30. end)
  31. if not ifArgument then return end
  32. return util.coerceValue(ifArgument.value, _type.arguments['if'], context.variables)
  33. end
  34. if isDirectiveActive('skip', types.skip) then return false end
  35. if isDirectiveActive('include', types.include) == false then return false end
  36. end
  37. return true
  38. end
  39. local function doesFragmentApply(fragment, type, context)
  40. if not fragment.typeCondition then return true end
  41. local innerType = typeFromAST(fragment.typeCondition, context.schema)
  42. if innerType == type then
  43. return true
  44. elseif innerType.__type == 'Interface' then
  45. local implementors = context.schema:getImplementors(innerType.name)
  46. return implementors and implementors[type]
  47. elseif innerType.__type == 'Union' then
  48. return util.find(innerType.types, function(member)
  49. return member == type
  50. end)
  51. end
  52. end
  53. local function mergeSelectionSets(fields)
  54. local selections = {}
  55. for i = 1, #fields do
  56. local selectionSet = fields[i].selectionSet
  57. if selectionSet then
  58. for j = 1, #selectionSet.selections do
  59. table.insert(selections, selectionSet.selections[j])
  60. end
  61. end
  62. end
  63. return selections
  64. end
  65. local function defaultResolver(object, arguments, info)
  66. return object[info.fieldASTs[1].name.value]
  67. end
  68. local function buildContext(schema, tree, rootValue, variables, operationName)
  69. local context = {
  70. schema = schema,
  71. rootValue = rootValue,
  72. variables = variables,
  73. operation = nil,
  74. fragmentMap = {}
  75. }
  76. for _, definition in ipairs(tree.definitions) do
  77. if definition.kind == 'operation' then
  78. if not operationName and context.operation then
  79. error('Operation name must be specified if more than one operation exists.')
  80. end
  81. if not operationName or definition.name.value == operationName then
  82. context.operation = definition
  83. end
  84. elseif definition.kind == 'fragmentDefinition' then
  85. context.fragmentMap[definition.name.value] = definition
  86. end
  87. end
  88. if not context.operation then
  89. if operationName then
  90. error('Unknown operation "' .. operationName .. '"')
  91. else
  92. error('Must provide an operation')
  93. end
  94. end
  95. return context
  96. end
  97. local function collectFields(objectType, selections, visitedFragments, result, context)
  98. for _, selection in ipairs(selections) do
  99. if selection.kind == 'field' then
  100. if shouldIncludeNode(selection, context) then
  101. local name = getFieldResponseKey(selection)
  102. result[name] = result[name] or {}
  103. table.insert(result[name], selection)
  104. end
  105. elseif selection.kind == 'inlineFragment' then
  106. if shouldIncludeNode(selection, context) and doesFragmentApply(selection, objectType, context) then
  107. collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
  108. end
  109. elseif selection.kind == 'fragmentSpread' then
  110. local fragmentName = selection.name.value
  111. if shouldIncludeNode(selection, context) and not visitedFragments[fragmentName] then
  112. visitedFragments[fragmentName] = true
  113. local fragment = context.fragmentMap[fragmentName]
  114. if fragment and shouldIncludeNode(fragment, context) and doesFragmentApply(fragment, objectType, context) then
  115. collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
  116. end
  117. end
  118. end
  119. end
  120. return result
  121. end
  122. local evaluateSelections
  123. local function completeValue(fieldType, result, subSelections, context)
  124. local fieldTypeName = fieldType.__type
  125. if fieldTypeName == 'NonNull' then
  126. local innerType = fieldType.ofType
  127. local completedResult = completeValue(innerType, result, subSelections, context)
  128. if completedResult == nil then
  129. error('No value provided for non-null ' .. (innerType.name or innerType.__type))
  130. end
  131. return completedResult
  132. end
  133. if result == nil then
  134. return nil
  135. end
  136. if fieldTypeName == 'List' then
  137. local innerType = fieldType.ofType
  138. if type(result) ~= 'table' then
  139. error('Expected a table for ' .. innerType.name .. ' list')
  140. end
  141. local values = {}
  142. for i, value in ipairs(result) do
  143. values[i] = completeValue(innerType, value, subSelections, context)
  144. end
  145. return next(values) and values or context.schema.__emptyList
  146. end
  147. if fieldTypeName == 'Scalar' or fieldTypeName == 'Enum' then
  148. return fieldType.serialize(result)
  149. end
  150. if fieldTypeName == 'Object' then
  151. local fields = evaluateSelections(fieldType, result, subSelections, context)
  152. return next(fields) and fields or context.schema.__emptyObject
  153. elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then
  154. local objectType = fieldType.resolveType(result)
  155. return evaluateSelections(objectType, result, subSelections, context)
  156. end
  157. error('Unknown type "' .. fieldTypeName .. '" for field "' .. field.name .. '"')
  158. end
  159. local function getFieldEntry(objectType, object, fields, context)
  160. local firstField = fields[1]
  161. local fieldName = firstField.name.value
  162. local responseKey = getFieldResponseKey(firstField)
  163. local fieldType = introspection.fieldMap[fieldName] or objectType.fields[fieldName]
  164. if fieldType == nil then
  165. return nil
  166. end
  167. local argumentMap = {}
  168. for _, argument in ipairs(firstField.arguments or {}) do
  169. argumentMap[argument.name.value] = argument
  170. end
  171. local arguments = util.map(fieldType.arguments or {}, function(argument, name)
  172. local supplied = argumentMap[name] and argumentMap[name].value
  173. return supplied and util.coerceValue(supplied, argument, context.variables) or argument.defaultValue
  174. end)
  175. local info = {
  176. fieldName = fieldName,
  177. fieldASTs = fields,
  178. returnType = fieldType.kind,
  179. parentType = objectType,
  180. schema = context.schema,
  181. fragments = context.fragmentMap,
  182. rootValue = context.rootValue,
  183. operation = context.operation,
  184. variableValues = context.variables
  185. }
  186. local resolvedObject = (fieldType.resolve or defaultResolver)(object, arguments, info)
  187. local subSelections = mergeSelectionSets(fields)
  188. return completeValue(fieldType.kind, resolvedObject, subSelections, context)
  189. end
  190. evaluateSelections = function(objectType, object, selections, context)
  191. local groupedFieldSet = collectFields(objectType, selections, {}, {}, context)
  192. return util.map(groupedFieldSet, function(fields)
  193. return getFieldEntry(objectType, object, fields, context)
  194. end)
  195. end
  196. return function(schema, tree, rootValue, variables, operationName)
  197. local context = buildContext(schema, tree, rootValue, variables, operationName)
  198. local rootType = schema[context.operation.operation]
  199. if not rootType then
  200. error('Unsupported operation "' .. context.operation.operation .. '"')
  201. end
  202. return evaluateSelections(rootType, rootValue, context.operation.selectionSet.selections, context)
  203. end