execute.lua 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. local path = (...):gsub('%.[^%.]+$', '')
  2. local types = require(path .. '.types')
  3. local util = require(path .. '.util')
  --- Resolve a type AST node (named, list, or non-null) to a schema type.
  -- @param node type AST node whose `kind` is 'listType', 'nonNullType' or 'namedType'
  -- @param schema schema whose `getType` method resolves named types
  -- @return the resolved (possibly wrapped) type; nil propagates out when the
  --   innermost `schema:getType` lookup returns nil
  local function typeFromAST(node, schema)
    if node.kind == 'listType' then
      -- The schema must be threaded through wrapper nodes; dropping it makes
      -- `[T]` / `T!` fail with "attempt to index a nil value" in getType.
      local innerType = typeFromAST(node.type, schema)
      return innerType and types.list(innerType)
    elseif node.kind == 'nonNullType' then
      local innerType = typeFromAST(node.type, schema)
      return innerType and types.nonNull(innerType)
    end
    assert(node.kind == 'namedType', 'Variable must be a named type')
    return schema:getType(node.name.value)
  end
  --- Key under which a field's value appears in the response.
  -- Uses the alias name when the field has one, otherwise the field name.
  local function getFieldResponseKey(field)
    local key
    if field.alias then
      key = field.alias.name.value
    end
    if key then
      return key
    end
    return field.name.value
  end
  --- Find directive `key` on `selection` and coerce its `if:` argument.
  -- @return the coerced `if` value, or nil when the directive or its `if`
  --   argument is absent
  local function evaluateDirectiveCondition(selection, key, directiveType, context)
    local found = util.find(selection.directives, function(candidate)
      return candidate.name.value == key
    end)
    if not found then
      return nil
    end
    local condition = util.find(found.arguments, function(arg)
      return arg.name.value == 'if'
    end)
    if not condition then
      return nil
    end
    return util.coerceValue(condition.value, directiveType.arguments['if'], context.variables)
  end

  --- Apply the @skip / @include directives to a selection.
  -- A selection is excluded when @skip's condition is truthy or @include's
  -- condition is exactly false; selections without directives are kept.
  -- @return boolean
  local function shouldIncludeNode(selection, context)
    if not selection.directives then
      return true
    end
    if evaluateDirectiveCondition(selection, 'skip', types.skip, context) then
      return false
    end
    return evaluateDirectiveCondition(selection, 'include', types.include, context) ~= false
  end
  --- Decide whether a fragment's type condition matches a concrete object type.
  -- Fragments without a type condition always apply. Otherwise the condition
  -- matches the type itself, an interface the schema lists it as implementing,
  -- or a union that contains it.
  -- @return truthy when the fragment applies, nil/false otherwise
  local function doesFragmentApply(fragment, objectType, context)
    local condition = fragment.typeCondition
    if not condition then
      return true
    end
    local conditionType = typeFromAST(condition, context.schema)
    if conditionType == objectType then
      return true
    end
    local conditionKind = conditionType.__type
    if conditionKind == 'Interface' then
      local implementors = context.schema:getImplementors(conditionType.name)
      return implementors and implementors[objectType]
    end
    if conditionKind == 'Union' then
      return util.find(conditionType.types, function(candidate)
        return candidate == objectType
      end)
    end
  end
  --- Concatenate the sub-selections of every field node, preserving order.
  -- Field nodes without a selection set contribute nothing.
  -- @param fields list of field AST nodes sharing one response key
  -- @return a new list of selection nodes
  local function mergeSelectionSets(fields)
    local merged = {}
    for _, field in ipairs(fields) do
      local set = field.selectionSet
      if set then
        for _, selection in ipairs(set.selections) do
          merged[#merged + 1] = selection
        end
      end
    end
    return merged
  end
  --- Fallback resolver: read the property named after the field from its parent.
  -- The arguments table is accepted for signature compatibility but unused.
  local function defaultResolver(object, _arguments, info)
    local fieldNode = info.fieldASTs[1]
    local key = fieldNode.name.value
    return object[key]
  end
  --- Build the execution context for a parsed document.
  -- Chooses the operation to execute and indexes fragment definitions by name.
  -- @param schema schema the document is executed against
  -- @param tree parsed document whose `definitions` are scanned
  -- @param rootValue value stored for root-level resolvers
  -- @param variables variable values for the operation
  -- @param operationName operation to run; may be nil when the document has
  --   exactly one operation
  -- @return context table with schema, rootValue, variables, operation, fragmentMap
  -- @raise when several operations exist without a name, the named operation
  --   is not found, or the document has no operation
  local function buildContext(schema, tree, rootValue, variables, operationName)
    local context = {
      schema = schema,
      rootValue = rootValue,
      variables = variables,
      operation = nil,
      fragmentMap = {}
    }

    for _, definition in ipairs(tree.definitions) do
      if definition.kind == 'operation' then
        if not operationName and context.operation then
          error('Operation name must be specified if more than one operation exists.')
        end
        -- Anonymous operations have no `name` node; guard before reading
        -- `.value` so a lookup by name skips them instead of crashing.
        local name = definition.name and definition.name.value
        if not operationName or name == operationName then
          context.operation = definition
        end
      elseif definition.kind == 'fragmentDefinition' then
        context.fragmentMap[definition.name.value] = definition
      end
    end

    if not context.operation then
      if operationName then
        error('Unknown operation "' .. operationName .. '"')
      else
        error('Must provide an operation')
      end
    end

    return context
  end
  --- Group the fields selected on `objectType` by response key.
  -- Inline fragments and fragment spreads are expanded in place when their
  -- directives and type conditions allow; `visitedFragments` ensures each
  -- named fragment is expanded at most once.
  -- @param result table mutated in place: response key -> list of field nodes
  -- @return `result`
  local function collectFields(objectType, selections, visitedFragments, result, context)
    for _, selection in ipairs(selections) do
      local kind = selection.kind
      if kind == 'field' then
        if shouldIncludeNode(selection, context) then
          local key = getFieldResponseKey(selection)
          local bucket = result[key]
          if not bucket then
            bucket = {}
            result[key] = bucket
          end
          bucket[#bucket + 1] = selection
        end
      elseif kind == 'inlineFragment' then
        local applies = shouldIncludeNode(selection, context)
          and doesFragmentApply(selection, objectType, context)
        if applies then
          collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
        end
      elseif kind == 'fragmentSpread' then
        local spreadName = selection.name.value
        if shouldIncludeNode(selection, context) and not visitedFragments[spreadName] then
          visitedFragments[spreadName] = true
          local definition = context.fragmentMap[spreadName]
          if definition
            and shouldIncludeNode(definition, context)
            and doesFragmentApply(definition, objectType, context) then
            collectFields(objectType, definition.selectionSet.selections, visitedFragments, result, context)
          end
        end
      end
    end
    return result
  end
  -- Forward declaration: completeValue and evaluateSelections are mutually recursive.
  local evaluateSelections

  --- Render a (possibly wrapped) type for error messages, e.g. `[Int]!`.
  -- List/NonNull wrappers are described via `ofType` rather than `name`, which
  -- the code elsewhere in this block only reads from unwrapped types.
  local function describeType(fieldType)
    if fieldType == nil then
      return 'nil'
    elseif fieldType.__type == 'NonNull' then
      return describeType(fieldType.ofType) .. '!'
    elseif fieldType.__type == 'List' then
      return '[' .. describeType(fieldType.ofType) .. ']'
    end
    return tostring(fieldType.name)
  end

  --- Convert a resolver's raw result into a response value for `fieldType`.
  -- @param fieldType the field's declared type, possibly wrapped in NonNull/List
  -- @param result raw value returned by the resolver
  -- @param subSelections merged selections evaluated for object/interface/union types
  -- @param context execution context
  -- @return the completed value (nil for a nil result of a nullable type)
  -- @raise when a non-null type yields nil, a list type yields a non-table,
  --   an interface/union cannot resolve a concrete type, or the kind is unknown
  local function completeValue(fieldType, result, subSelections, context)
    local fieldTypeName = fieldType.__type

    if fieldTypeName == 'NonNull' then
      local innerType = fieldType.ofType
      -- Forward subSelections so non-null object/list fields keep their selections.
      local completedResult = completeValue(innerType, result, subSelections, context)
      if completedResult == nil then
        error('No value provided for non-null ' .. describeType(innerType))
      end
      return completedResult
    end

    if result == nil then
      return nil
    end

    if fieldTypeName == 'List' then
      local innerType = fieldType.ofType
      if type(result) ~= 'table' then
        error('Expected a table for ' .. describeType(innerType) .. ' list')
      end
      -- NOTE(review): ipairs stops at the first nil element, so a list that
      -- contains null entries is truncated at that point.
      local values = {}
      for i, value in ipairs(result) do
        values[i] = completeValue(innerType, value, subSelections, context)
      end
      return values
    end

    if fieldTypeName == 'Scalar' or fieldTypeName == 'Enum' then
      return fieldType.serialize(result)
    end

    if fieldTypeName == 'Object' then
      return evaluateSelections(fieldType, result, subSelections, context)
    elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then
      local objectType = fieldType.resolveType(result)
      if not objectType then
        error('Could not resolve a concrete type for ' .. describeType(fieldType))
      end
      return evaluateSelections(objectType, result, subSelections, context)
    end

    error('Unknown type "' .. tostring(fieldTypeName) .. '"')
  end
  --- Resolve one response field and complete its value.
  -- All nodes in `fields` share a response key. The first node supplies the
  -- field name and arguments, and every node contributes sub-selections.
  -- @param objectType type owning the field; its `fields` table is consulted
  -- @param object parent value passed to the resolver
  -- @param fields list of field AST nodes grouped under one response key
  -- @param context execution context
  -- @return the completed value, or nil when `objectType` has no such field
  local function getFieldEntry(objectType, object, fields, context)
    local firstField = fields[1]
    local fieldName = firstField.name.value
    local fieldType = objectType.fields[fieldName]
    if fieldType == nil then
      return nil
    end

    local argumentMap = {}
    for _, argument in ipairs(firstField.arguments or {}) do
      argumentMap[argument.name.value] = argument
    end

    local arguments = util.map(fieldType.arguments or {}, function(argument, name)
      local supplied = argumentMap[name] and argumentMap[name].value
      if supplied then
        local value = util.coerceValue(supplied, argument, context.variables)
        -- Explicit nil check: the `a and b or c` idiom would also replace a
        -- supplied `false` with the default value.
        if value ~= nil then
          return value
        end
      end
      return argument.defaultValue
    end)

    local info = {
      fieldName = fieldName,
      fieldASTs = fields,
      returnType = fieldType.kind,
      parentType = objectType,
      schema = context.schema,
      fragments = context.fragmentMap,
      rootValue = context.rootValue,
      operation = context.operation,
      variableValues = context.variables
    }

    local resolvedObject = (fieldType.resolve or defaultResolver)(object, arguments, info)
    local subSelections = mergeSelectionSets(fields)
    return completeValue(fieldType.kind, resolvedObject, subSelections, context)
  end
  --- Execute `selections` against `object`, whose type is `objectType`.
  -- Fields are first grouped by collectFields, then each group is resolved by
  -- getFieldEntry; presumably the result is keyed by response key (util.map
  -- over the grouped field set).
  evaluateSelections = function(objectType, object, selections, context)
    local grouped = collectFields(objectType, selections, {}, {}, context)
    local function resolveGroup(fieldNodes)
      return getFieldEntry(objectType, object, fieldNodes, context)
    end
    return util.map(grouped, resolveGroup)
  end
  --- Module entry point: execute a parsed GraphQL document.
  -- The root type is looked up on `schema` using the selected operation's
  -- `operation` field as the key.
  -- @param schema schema to execute against
  -- @param tree parsed document
  -- @param rootValue value handed to root-level resolvers
  -- @param variables variable values for the operation
  -- @param operationName operation to run (optional when only one exists)
  -- @return the result of evaluating the operation's top-level selections
  -- @raise when `schema` has no root type for the operation's kind
  return function(schema, tree, rootValue, variables, operationName)
    local context = buildContext(schema, tree, rootValue, variables, operationName)
    local operation = context.operation
    local operationType = operation.operation
    local rootType = schema[operationType]
    if not rootType then
      error('Unsupported operation "' .. operationType .. '"')
    end
    return evaluateSelections(rootType, rootValue, operation.selectionSet.selections, context)
  end