execute.lua 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. local path = (...):gsub('%.[^%.]+$', '')
  2. local types = require(path .. '.types')
  3. local util = require(path .. '.util')
  4. local introspection = require(path .. '.introspection')
  5. local cjson = require 'cjson' -- needs to be cloned from here https://github.com/openresty/lua-cjson for cjson.empty_array feature
  6. local function typeFromAST(node, schema)
  7. local innerType
  8. if node.kind == 'listType' then
  9. innerType = typeFromAST(node.type)
  10. return innerType and types.list(innerType)
  11. elseif node.kind == 'nonNullType' then
  12. innerType = typeFromAST(node.type)
  13. return innerType and types.nonNull(innerType)
  14. else
  15. assert(node.kind == 'namedType', 'Variable must be a named type')
  16. return schema:getType(node.name.value)
  17. end
  18. end
  19. local function getFieldResponseKey(field)
  20. return field.alias and field.alias.name.value or field.name.value
  21. end
  22. local function shouldIncludeNode(selection, context)
  23. if selection.directives then
  24. local function isDirectiveActive(key, _type)
  25. local directive = util.find(selection.directives, function(directive)
  26. return directive.name.value == key
  27. end)
  28. if not directive then return end
  29. local ifArgument = util.find(directive.arguments, function(argument)
  30. return argument.name.value == 'if'
  31. end)
  32. if not ifArgument then return end
  33. return util.coerceValue(ifArgument.value, _type.arguments['if'], context.variables)
  34. end
  35. if isDirectiveActive('skip', types.skip) then return false end
  36. if isDirectiveActive('include', types.include) == false then return false end
  37. end
  38. return true
  39. end
  40. local function doesFragmentApply(fragment, type, context)
  41. if not fragment.typeCondition then return true end
  42. local innerType = typeFromAST(fragment.typeCondition, context.schema)
  43. if innerType == type then
  44. return true
  45. elseif innerType.__type == 'Interface' then
  46. local implementors = context.schema:getImplementors(innerType.name)
  47. return implementors and implementors[type]
  48. elseif innerType.__type == 'Union' then
  49. return util.find(innerType.types, function(member)
  50. return member == type
  51. end)
  52. end
  53. end
  54. local function mergeSelectionSets(fields)
  55. local selections = {}
  56. for i = 1, #fields do
  57. local selectionSet = fields[i].selectionSet
  58. if selectionSet then
  59. for j = 1, #selectionSet.selections do
  60. table.insert(selections, selectionSet.selections[j])
  61. end
  62. end
  63. end
  64. return selections
  65. end
  66. local function defaultResolver(object, arguments, info)
  67. return object[info.fieldASTs[1].name.value]
  68. end
  69. local function buildContext(schema, tree, rootValue, variables, operationName)
  70. local context = {
  71. schema = schema,
  72. rootValue = rootValue,
  73. variables = variables,
  74. operation = nil,
  75. fragmentMap = {}
  76. }
  77. for _, definition in ipairs(tree.definitions) do
  78. if definition.kind == 'operation' then
  79. if not operationName and context.operation then
  80. error('Operation name must be specified if more than one operation exists.')
  81. end
  82. if not operationName or definition.name.value == operationName then
  83. context.operation = definition
  84. end
  85. elseif definition.kind == 'fragmentDefinition' then
  86. context.fragmentMap[definition.name.value] = definition
  87. end
  88. end
  89. if not context.operation then
  90. if operationName then
  91. error('Unknown operation "' .. operationName .. '"')
  92. else
  93. error('Must provide an operation')
  94. end
  95. end
  96. return context
  97. end
  98. local function collectFields(objectType, selections, visitedFragments, result, context)
  99. for _, selection in ipairs(selections) do
  100. if selection.kind == 'field' then
  101. if shouldIncludeNode(selection, context) then
  102. local name = getFieldResponseKey(selection)
  103. result[name] = result[name] or {}
  104. table.insert(result[name], selection)
  105. end
  106. elseif selection.kind == 'inlineFragment' then
  107. if shouldIncludeNode(selection, context) and doesFragmentApply(selection, objectType, context) then
  108. collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
  109. end
  110. elseif selection.kind == 'fragmentSpread' then
  111. local fragmentName = selection.name.value
  112. if shouldIncludeNode(selection, context) and not visitedFragments[fragmentName] then
  113. visitedFragments[fragmentName] = true
  114. local fragment = context.fragmentMap[fragmentName]
  115. if fragment and shouldIncludeNode(fragment, context) and doesFragmentApply(fragment, objectType, context) then
  116. collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
  117. end
  118. end
  119. end
  120. end
  121. return result
  122. end
  123. local evaluateSelections
  124. local function completeValue(fieldType, result, subSelections, context)
  125. local fieldTypeName = fieldType.__type
  126. if fieldTypeName == 'NonNull' then
  127. local innerType = fieldType.ofType
  128. local completedResult = completeValue(innerType, result, subSelections, context)
  129. if completedResult == nil then
  130. error('No value provided for non-null ' .. innerType.name)
  131. end
  132. return completedResult
  133. end
  134. if result == nil then
  135. return nil
  136. end
  137. if fieldTypeName == 'List' then
  138. if result == cjson.empty_array then return result end
  139. local innerType = fieldType.ofType
  140. if type(result) ~= 'table' then
  141. error('Expected a table for ' .. innerType.name .. ' list')
  142. end
  143. local values = {}
  144. for i, value in ipairs(result) do
  145. values[i] = completeValue(innerType, value, subSelections, context)
  146. end
  147. return values
  148. end
  149. if fieldTypeName == 'Scalar' then
  150. return fieldType.serialize(result)
  151. end
  152. if fieldTypeName == 'Enum' then
  153. return fieldType:serialize(result)
  154. end
  155. if fieldTypeName == 'Object' then
  156. return evaluateSelections(fieldType, result, subSelections, context)
  157. elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then
  158. local objectType = fieldType.resolveType(result)
  159. return evaluateSelections(objectType, result, subSelections, context)
  160. end
  161. error('Unknown type "' .. fieldTypeName .. '" for field "' .. field.name .. '"')
  162. end
  163. local function getFieldEntry(objectType, object, fields, context)
  164. local firstField = fields[1]
  165. local fieldName = firstField.name.value
  166. local responseKey = getFieldResponseKey(firstField)
  167. local fieldType
  168. if fieldName == '__schema' then
  169. fieldType = introspection.SchemaMetaFieldDef
  170. elseif fieldName == '__type' then
  171. fieldType = introspection.TypeMetaFieldDef
  172. elseif fieldName == '__typename' then
  173. fieldType = introspection.TypeNameMetaFieldDef
  174. else
  175. fieldType = objectType.fields[fieldName]
  176. end
  177. if fieldType == nil then
  178. return nil
  179. end
  180. local argumentMap = {}
  181. for _, argument in ipairs(firstField.arguments or {}) do
  182. argumentMap[argument.name.value] = argument
  183. end
  184. local arguments = util.map(fieldType.arguments or {}, function(argument, name)
  185. local supplied = argumentMap[name] and argumentMap[name].value
  186. return supplied and util.coerceValue(argumentMap[name].value, argument.kind, context.variables) or argument.defaultValue
  187. end)
  188. local info = {
  189. fieldName = fieldName,
  190. fieldASTs = fields,
  191. returnType = fieldType.kind,
  192. parentType = objectType,
  193. schema = context.schema,
  194. fragments = context.fragmentMap,
  195. rootValue = context.rootValue,
  196. operation = context.operation,
  197. variableValues = context.variables
  198. }
  199. local resolvedObject = (fieldType.resolve or defaultResolver)(object, arguments, info)
  200. local subSelections = mergeSelectionSets(fields)
  201. local responseValue = completeValue(fieldType.kind, resolvedObject, subSelections, context)
  202. return responseValue
  203. end
  204. evaluateSelections = function(objectType, object, selections, context)
  205. local groupedFieldSet = collectFields(objectType, selections, {}, {}, context)
  206. return util.map(groupedFieldSet, function(fields)
  207. local v = getFieldEntry(objectType, object, fields, context)
  208. if v ~= nil then return v else return cjson.null end
  209. end)
  210. end
  211. return function(schema, tree, rootValue, variables, operationName)
  212. local context = buildContext(schema, tree, rootValue, variables, operationName)
  213. local rootType = schema[context.operation.operation]
  214. if not rootType then
  215. error('Unsupported operation "' .. context.operation.operation .. '"')
  216. end
  217. return evaluateSelections(rootType, rootValue, context.operation.selectionSet.selections, context)
  218. end