rules.lua 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. local path = (...):gsub('%.[^%.]+$', '')
  2. local types = require(path .. '.types')
  3. local util = require(path .. '.util')
  4. local getParentField = require(path .. '.schema').getParentField
  5. local rules = {}
  --- Ensures every named operation in the document has a distinct name.
  -- Anonymous operations are skipped. Names seen so far are recorded in
  -- context.operationNames.
  -- @param node operation definition node
  -- @param context validation context
  function rules.uniqueOperationNames(node, context)
    local operationName = node.name and node.name.value
    if not operationName then return end

    local known = context.operationNames
    if known[operationName] then
      error('Multiple operations exist named "' .. operationName .. '"')
    end
    known[operationName] = true
  end
  --- Rejects documents that combine an anonymous operation with any other operation.
  -- Reads context.operationNames (filled by uniqueOperationNames) and maintains the
  -- context.hasAnonymousOperation flag.
  -- @param node operation definition node
  -- @param context validation context
  function rules.loneAnonymousOperation(node, context)
    local isAnonymous = not (node.name and node.name.value)
    local conflict = context.hasAnonymousOperation
      or (isAnonymous and next(context.operationNames) ~= nil)

    if conflict then
      error('Cannot have more than one operation when using anonymous operations')
    end

    if isAnonymous then
      context.hasAnonymousOperation = true
    end
  end
  --- Errors when a selected field does not exist on its parent type.
  -- A `false` entry on top of context.objects marks a field that was not found in
  -- the parent's definition; the parent type sits one slot below it.
  -- @param node field node
  -- @param context validation context
  function rules.fieldsDefinedOnType(node, context)
    if context.objects[#context.objects] == false then
      local parent = context.objects[#context.objects - 1]
      -- Peel off every NonNull/List layer (e.g. `[Foo!]!`) so the message names the
      -- underlying type; wrappers expose the wrapped type only through `ofType`.
      while parent.ofType do
        parent = parent.ofType
      end
      error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
    end
  end
  --- Errors when a field is passed an argument its definition does not declare.
  -- @param node field node (node.arguments may be nil)
  -- @param context validation context, used to resolve the field definition
  function rules.argumentsDefinedOnType(node, context)
    local supplied = node.arguments
    if not supplied then return end

    local parentField = getParentField(context, node.name.value)
    for _, arg in pairs(supplied) do
      local argName = arg.name.value
      if not parentField.arguments[argName] then
        error('Non-existent argument "' .. argName .. '"')
      end
    end
  end
  --- Errors when a scalar-typed field is given a sub-selection.
  -- The top of context.objects holds the selected field's type, which may be
  -- wrapped in NonNull/List layers; unwrap it so `String!` or `[String]` fields are
  -- recognised as scalars too.
  -- @param node field node
  -- @param context validation context
  function rules.scalarFieldsAreLeaves(node, context)
    local fieldType = context.objects[#context.objects]
    -- `false` marks a field missing from its parent type; fieldsDefinedOnType
    -- reports that case, so there is nothing to check here.
    if not fieldType then return end
    while fieldType.ofType do
      fieldType = fieldType.ofType
    end
    if fieldType.__type == 'Scalar' and node.selectionSet then
      error('Scalar values cannot have subselections')
    end
  end
  --- Errors when a field of Object, Interface or Union type has no sub-selection.
  -- The top of context.objects holds the field's type, which may be wrapped in
  -- NonNull/List layers (e.g. `[User!]!`); unwrap it so wrapped composites are
  -- still required to select sub-fields.
  -- @param node field node
  -- @param context validation context
  function rules.compositeFieldsAreNotLeaves(node, context)
    local fieldType = context.objects[#context.objects]
    -- `false` marks a field missing from its parent type; fieldsDefinedOnType
    -- reports that case.
    if not fieldType then return end
    while fieldType.ofType do
      fieldType = fieldType.ofType
    end
    local category = fieldType.__type
    local isCompositeType = category == 'Object' or category == 'Interface' or category == 'Union'
    if isCompositeType and not node.selectionSet then
      error('Composite types must have subselections')
    end
  end
  --- Errors when selections that share a response key (alias or field name) could
  -- conflict: different underlying fields, different return types, or different
  -- arguments. Inline fragments and fragment spreads are followed so fields
  -- contributed through them are compared as well.
  -- @param node the selection set being validated
  -- @param context validation context (objects stack, schema, fragmentMap)
  function rules.unambiguousSelections(node, context)
    -- Response key -> list of field entries collected so far.
    local selectionMap = {}
    -- Fragment definitions already expanded; also stops recursion on cyclic fragments.
    local seen = {}

    -- Strip NonNull/List wrappers to reach the named type that carries `fields`.
    -- nil/false pass through unchanged.
    local function unwrap(kind)
      while kind and kind.ofType do
        kind = kind.ofType
      end
      return kind
    end

    -- Compare two argument value nodes. Literals compare by `value`; variables by
    -- name (their `value` field is nil); list and input object literals item by item.
    -- NOTE(review): assumes list literals keep their items in `values`, as input
    -- objects do — confirm against the parser.
    local function sameValue(a, b)
      if a.kind ~= b.kind then return false end
      if a.kind == 'variable' then
        return a.name.value == b.name.value
      elseif a.kind == 'list' or a.kind == 'inputObject' then
        local itemsA, itemsB = a.values or {}, b.values or {}
        if #itemsA ~= #itemsB then return false end
        for i = 1, #itemsA do
          local itemA, itemB = itemsA[i], itemsB[i]
          if a.kind == 'inputObject' then
            if itemA.name ~= itemB.name or not sameValue(itemA.value, itemB.value) then
              return false
            end
          elseif not sameValue(itemA, itemB) then
            return false
          end
        end
        return true
      end
      return a.value == b.value
    end

    -- Returns an error message when two entries for the same response key conflict.
    local function findConflict(entryA, entryB)
      -- Two distinct object types never describe the same runtime value, so their
      -- selections cannot collide. Interface and union types may overlap.
      if entryA.parent ~= entryB.parent
        and entryA.parent.__type == 'Object'
        and entryB.parent.__type == 'Object' then
        return
      end
      -- Aliases mapping two different fields to the same name.
      if entryA.field.name.value ~= entryB.field.name.value then
        return 'Type name mismatch'
      end
      -- Same field name but different return types.
      if entryA.definition and entryB.definition and entryA.definition ~= entryB.definition then
        return 'Return type mismatch'
      end
      -- Arguments must be identical for two fields with the same name.
      local argsA = entryA.field.arguments or {}
      local argsB = entryB.field.arguments or {}
      if #argsA ~= #argsB then
        return 'Argument mismatch'
      end
      local argMap = {}
      for i = 1, #argsA do
        argMap[argsA[i].name.value] = argsA[i].value
      end
      for i = 1, #argsB do
        local counterpart = argMap[argsB[i].name.value]
        if not counterpart or not sameValue(counterpart, argsB[i].value) then
          return 'Argument mismatch'
        end
      end
    end

    -- Check an entry against all earlier entries with the same key, then record it.
    local function validateField(key, entry)
      local entries = selectionMap[key]
      if not entries then
        selectionMap[key] = { entry }
        return
      end
      for i = 1, #entries do
        local conflict = findConflict(entries[i], entry)
        if conflict then
          error(conflict)
        end
      end
      table.insert(entries, entry)
    end

    -- Recursively collect field selections, following fragments.
    local function validateSelectionSet(selectionSet, parentType)
      parentType = unwrap(parentType)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'field' then
          local fieldName = selection.name.value
          local fieldDefinition = parentType and parentType.fields and parentType.fields[fieldName]
          -- Unknown fields (including introspection fields such as __typename) are
          -- skipped individually so the remaining selections are still checked.
          if fieldDefinition then
            local key = selection.alias and selection.alias.name.value or fieldName
            validateField(key, {
              parent = parentType,
              field = selection,
              definition = fieldDefinition.kind
            })
          end
        elseif selection.kind == 'inlineFragment' then
          local fragmentType = parentType
          if selection.typeCondition then
            fragmentType = context.schema:getType(selection.typeCondition.name.value) or parentType
          end
          validateSelectionSet(selection.selectionSet, fragmentType)
        elseif selection.kind == 'fragmentSpread' then
          local fragmentDefinition = context.fragmentMap[selection.name.value]
          if fragmentDefinition and not seen[fragmentDefinition] then
            seen[fragmentDefinition] = true
            if fragmentDefinition.typeCondition then
              local fragmentType = context.schema:getType(fragmentDefinition.typeCondition.name.value)
              validateSelectionSet(fragmentDefinition.selectionSet, fragmentType)
            end
          end
        end
      end
    end

    validateSelectionSet(node, context.objects[#context.objects])
  end
  --- Errors when a node's argument list names the same argument more than once.
  -- @param node node carrying an optional `arguments` list
  -- @param context validation context (unused)
  function rules.uniqueArgumentNames(node, context)
    local supplied = node.arguments
    if not supplied then return end

    local seenNames = {}
    for _, arg in ipairs(supplied) do
      local argName = arg.name.value
      if seenNames[argName] then
        error('Encountered multiple arguments named "' .. argName .. '"')
      end
      seenNames[argName] = true
    end
  end
  --- Runs each supplied argument literal through util.coerceValue against the
  -- argument's declared type (presumably raising when the value does not fit).
  -- A declaration may be a bare type or a table holding the type in `kind`.
  -- @param node field node (node.arguments may be nil)
  -- @param context validation context, used to resolve the field definition
  function rules.argumentsOfCorrectType(node, context)
    if not node.arguments then return end

    local parentField = getParentField(context, node.name.value)
    for _, arg in pairs(node.arguments) do
      local declared = parentField.arguments[arg.name.value]
      local expectedType = declared.kind or declared
      util.coerceValue(arg.value, expectedType)
    end
  end
  --- Errors when a field omits an argument whose declared type is NonNull.
  -- Argument declarations may be a bare type or a table carrying the type in `kind`
  -- (the form argumentsOfCorrectType also accepts), so the declaration is
  -- normalised before testing for NonNull; otherwise `{ kind = nonNull(...) }`
  -- declarations would never be enforced.
  -- @param node field node (node.arguments may be nil)
  -- @param context validation context, used to resolve the field definition
  function rules.requiredArgumentsPresent(node, context)
    -- Names of the arguments actually supplied, for O(1) lookup.
    local supplied = {}
    for _, argument in ipairs(node.arguments or {}) do
      supplied[argument.name.value] = true
    end

    local parentField = getParentField(context, node.name.value)
    for name, declaration in pairs(parentField.arguments) do
      local argumentType = declaration.kind or declaration
      if argumentType.__type == 'NonNull' and not supplied[name] then
        error('Required argument "' .. name .. '" was not supplied.')
      end
    end
  end
  --- Errors when two fragment definitions in the document share a name.
  -- @param node document node with a `definitions` list
  -- @param context validation context (unused)
  function rules.uniqueFragmentNames(node, context)
    local known = {}
    for _, def in ipairs(node.definitions) do
      local isFragment = def.kind == 'fragmentDefinition'
      if isFragment then
        local fragmentName = def.name.value
        if known[fragmentName] then
          error('Encountered multiple fragments named "' .. fragmentName .. '"')
        end
        known[fragmentName] = true
      end
    end
  end
  --- Errors when a fragment's type condition names a type the schema lacks, or a
  -- type other than Object, Interface or Union. Fragments without a type condition
  -- are accepted as-is.
  -- @param node fragment node (inline or named)
  -- @param context validation context providing the schema
  function rules.fragmentHasValidType(node, context)
    local condition = node.typeCondition
    if condition then
      local typeName = condition.name.value
      local target = context.schema:getType(typeName)
      if not target then
        error('Fragment refers to non-existent type "' .. typeName .. '"')
      end
      local category = target.__type
      if not (category == 'Object' or category == 'Interface' or category == 'Union') then
        error('Fragment type must be an Object, Interface, or Union, got ' .. category)
      end
    end
  end
  --- Errors when the document defines a fragment that is never referenced.
  -- Usage is read from context.usedFragments, keyed by fragment name.
  -- @param node document node with a `definitions` list
  -- @param context validation context
  function rules.noUnusedFragments(node, context)
    local used = context.usedFragments
    for _, def in ipairs(node.definitions) do
      if def.kind == 'fragmentDefinition' and not used[def.name.value] then
        error('Fragment "' .. def.name.value .. '" was not used.')
      end
    end
  end
  --- Errors when a fragment spread names a fragment missing from context.fragmentMap.
  -- @param node fragment spread node
  -- @param context validation context
  function rules.fragmentSpreadTargetDefined(node, context)
    local target = node.name.value
    if context.fragmentMap[target] then return end
    error('Fragment spread refers to non-existent fragment "' .. target .. '"')
  end
  --- Errors when a fragment definition can reach itself through fragment spreads
  -- anywhere in its selection set (including under sub-fields).
  -- Uses a depth-first walk that tracks only the fragments on the current path, so
  -- a fragment spread several times without forming a loop is not misreported.
  -- @param node fragment definition node
  -- @param context validation context providing fragmentMap
  function rules.fragmentDefinitionHasNoCycles(node, context)
    -- Fragments on the current DFS path; spreading one of these closes a cycle.
    local onPath = { [node.name.value] = true }
    -- Fragments already fully explored without finding a cycle.
    local finished = {}

    local function detectCycles(selectionSet)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'field' then
          -- Spreads nested under sub-fields can also close a cycle.
          if selection.selectionSet then
            detectCycles(selection.selectionSet)
          end
        elseif selection.kind == 'inlineFragment' then
          detectCycles(selection.selectionSet)
        elseif selection.kind == 'fragmentSpread' then
          local name = selection.name.value
          if onPath[name] then
            error('Fragment definition has cycles')
          end
          local fragmentDefinition = context.fragmentMap[name]
          -- Undefined targets are reported by fragmentSpreadTargetDefined.
          if fragmentDefinition and not finished[name] then
            onPath[name] = true
            detectCycles(fragmentDefinition.selectionSet)
            onPath[name] = nil
            finished[name] = true
          end
        end
      end
    end

    detectCycles(node.selectionSet)
  end
  --- Errors when a fragment's type condition can never apply to the enclosing type,
  -- i.e. the sets of possible object types of the two do not intersect.
  -- Handles inline fragments (node.kind == 'inlineFragment') and named spreads.
  -- Unknown types and undefined fragments are left to other rules.
  -- @param node inline fragment or fragment spread node
  -- @param context validation context (objects stack, schema, fragmentMap)
  function rules.fragmentSpreadIsPossible(node, context)
    local parentType = context.objects[#context.objects - 1]
    if not parentType then return end
    -- Peel off every NonNull/List layer to reach the named enclosing type.
    while parentType.ofType do
      parentType = parentType.ofType
    end

    local fragmentType
    if node.kind == 'inlineFragment' then
      fragmentType = parentType
      if node.typeCondition then
        fragmentType = context.schema:getType(node.typeCondition.name.value) or parentType
      end
    else
      local fragment = context.fragmentMap[node.name.value]
      if not fragment or not fragment.typeCondition then return end
      fragmentType = context.schema:getType(fragment.typeCondition.name.value)
    end
    if not fragmentType then return end

    -- Set of concrete object types a type may resolve to, keyed by type table.
    local function possibleTypes(kind)
      if kind.__type == 'Object' then
        return { [kind] = kind }
      elseif kind.__type == 'Interface' then
        -- getImplementors may yield nil; treat that as "no possible types".
        return context.schema:getImplementors(kind.name) or {}
      elseif kind.__type == 'Union' then
        local members = {}
        for i = 1, #kind.types do
          members[kind.types[i]] = kind.types[i]
        end
        return members
      end
      return {}
    end

    local parentTypes = possibleTypes(parentType)
    local fragmentTypes = possibleTypes(fragmentType)
    local valid = util.find(parentTypes, function(kind)
      return fragmentTypes[kind]
    end)
    if not valid then
      error('Fragment type condition is not possible for given type')
    end
  end
  --- Errors when an input object literal repeats a field name, at any nesting depth.
  -- Walks node.value, descending into list literals and nested input objects.
  -- @param node argument node carrying a `value`
  -- @param context validation context (unused)
  function rules.uniqueInputObjectFields(node, context)
    local function validateValue(value)
      if value.kind == 'list' then
        -- List literals are value nodes of kind 'list' (type-AST kinds such as
        -- listType never appear inside values); check each item.
        -- NOTE(review): assumes items are stored in `values` — confirm in the parser.
        for _, item in ipairs(value.values or {}) do
          validateValue(item)
        end
      elseif value.kind == 'inputObject' then
        local fieldMap = {}
        for _, field in ipairs(value.values) do
          if fieldMap[field.name] then
            error('Multiple input object fields named "' .. field.name .. '"')
          end
          fieldMap[field.name] = true
          validateValue(field.value)
        end
      end
    end

    validateValue(node.value)
  end
  --- Errors when a node uses a directive the schema does not declare.
  -- @param node any node with an optional `directives` list
  -- @param context validation context providing the schema
  function rules.directivesAreDefined(node, context)
    local directives = node.directives
    if directives then
      for _, directive in pairs(directives) do
        local directiveName = directive.name.value
        if not context.schema:getDirective(directiveName) then
          error('Unknown directive "' .. directiveName .. '"')
        end
      end
    end
  end
  --- Errors when an operation declares a variable whose named type (after stripping
  -- list/non-null wrappers) is unknown to the schema or is not a Scalar, Enum or
  -- InputObject.
  -- @param node operation node with optional `variableDefinitions`
  -- @param context validation context providing the schema
  function rules.variablesHaveCorrectType(node, context)
    local definitions = node.variableDefinitions
    if not definitions then return end

    local function checkType(typeNode)
      while typeNode.kind == 'listType' or typeNode.kind == 'nonNullType' do
        typeNode = typeNode.type
      end
      if typeNode.kind ~= 'namedType' then return end

      local typeName = typeNode.name.value
      local schemaType = context.schema:getType(typeName)
      if not schemaType then
        error('Variable specifies unknown type "' .. tostring(typeName) .. '"')
      end
      local category = schemaType.__type
      if category ~= 'Scalar' and category ~= 'Enum' and category ~= 'InputObject' then
        error('Variable types must be scalars, enums, or input objects, got "' .. category .. '"')
      end
    end

    for _, definition in ipairs(definitions) do
      checkType(definition.type)
    end
  end
  --- Validates default values declared on an operation's variables.
  -- Non-null variables may not declare a default. Any other default literal is
  -- passed to util.coerceValue together with the variable's full declared type,
  -- so list-typed variables such as `[ID] = [1, 2]` are handled.
  -- @param node operation node with optional `variableDefinitions`
  -- @param context validation context providing the schema
  function rules.variableDefaultValuesHaveCorrectType(node, context)
    if not node.variableDefinitions then return end

    -- Build the schema type described by a variable's type AST. Returns nil when a
    -- named type is unknown; variablesHaveCorrectType reports unknown types.
    local function typeFromAST(typeNode)
      if typeNode.kind == 'listType' then
        local innerType = typeFromAST(typeNode.type)
        return innerType and types.list(innerType)
      elseif typeNode.kind == 'nonNullType' then
        local innerType = typeFromAST(typeNode.type)
        return innerType and types.nonNull(innerType)
      end
      return context.schema:getType(typeNode.name.value)
    end

    for _, definition in ipairs(node.variableDefinitions) do
      if definition.defaultValue then
        if definition.type.kind == 'nonNullType' then
          error('Non-null variables can not have default values')
        end
        local variableType = typeFromAST(definition.type)
        if variableType then
          util.coerceValue(definition.defaultValue, variableType)
        end
      end
    end
  end
  --- Errors when an operation declares a variable that is never referenced.
  -- References are read from context.variableReferences, keyed by variable name.
  -- @param node operation node with optional `variableDefinitions`
  -- @param context validation context
  function rules.variablesAreUsed(node, context)
    local definitions = node.variableDefinitions
    if not definitions then return end

    for _, definition in ipairs(definitions) do
      local name = definition.variable.name.value
      if not context.variableReferences[name] then
        error('Unused variable "' .. name .. '"')
      end
    end
  end
  --- Errors when an operation references a variable it does not declare.
  -- Does nothing when context.variableReferences is unset.
  -- @param node operation node with optional `variableDefinitions`
  -- @param context validation context
  function rules.variablesAreDefined(node, context)
    local references = context.variableReferences
    if not references then return end

    local declared = {}
    for _, definition in ipairs(node.variableDefinitions or {}) do
      declared[definition.variable.name.value] = true
    end

    for name in pairs(references) do
      if not declared[name] then
        error('Unknown variable "' .. name .. '"')
      end
    end
  end
  --- Errors with 'Variable type mismatch' when a variable is passed to an argument
  -- whose declared type the variable's type cannot satisfy. A variable with a
  -- default value is treated as non-null.
  -- For a field node, the field's own arguments are checked. For a fragment spread,
  -- the arguments of every field reachable through the fragment are collected and
  -- each field name is resolved with getParentField.
  -- Undeclared variables, undeclared arguments, unknown fields and unknown types
  -- are skipped here; argumentsDefinedOnType, variablesAreDefined,
  -- fieldsDefinedOnType and variablesHaveCorrectType report those.
  -- NOTE(review): fields gathered from a fragment are all resolved against the
  -- spread's enclosing type rather than their actual parent, so nested or
  -- type-conditioned fields may be skipped or mis-resolved.
  -- @param node field or fragment spread node
  -- @param context validation context (currentOperation, schema, fragmentMap)
  function rules.variableUsageAllowed(node, context)
    local operation = context.currentOperation
    if not operation then return end

    local variableMap = {}
    for _, definition in ipairs(operation.variableDefinitions or {}) do
      variableMap[definition.variable.name.value] = definition
    end

    -- Field name -> list of argument nodes supplied to fields of that name.
    local arguments
    if node.kind == 'field' then
      arguments = { [node.name.value] = node.arguments }
    elseif node.kind == 'fragmentSpread' then
      local fragment = context.fragmentMap[node.name.value]
      if fragment then
        arguments = {}
        -- Selection nodes already visited; stops repeated or cyclic expansion.
        local seen = {}
        local function collectArguments(referencedNode)
          if referencedNode.kind == 'selectionSet' then
            for _, selection in ipairs(referencedNode.selections) do
              if not seen[selection] then
                seen[selection] = true
                collectArguments(selection)
              end
            end
          elseif referencedNode.kind == 'field' and referencedNode.arguments then
            local fieldName = referencedNode.name.value
            arguments[fieldName] = arguments[fieldName] or {}
            for _, argument in ipairs(referencedNode.arguments) do
              table.insert(arguments[fieldName], argument)
            end
          elseif referencedNode.kind == 'inlineFragment' then
            return collectArguments(referencedNode.selectionSet)
          elseif referencedNode.kind == 'fragmentSpread' then
            local spread = context.fragmentMap[referencedNode.name.value]
            return spread and collectArguments(spread.selectionSet)
          end
        end
        collectArguments(fragment.selectionSet)
      end
    end
    if not arguments then return end

    -- Build the schema type described by a variable's type AST (nil if unknown).
    local function typeFromAST(typeNode)
      if typeNode.kind == 'listType' then
        local innerType = typeFromAST(typeNode.type)
        return innerType and types.list(innerType)
      elseif typeNode.kind == 'nonNullType' then
        local innerType = typeFromAST(typeNode.type)
        return innerType and types.nonNull(innerType)
      end
      assert(typeNode.kind == 'namedType', 'Variable must be a named type')
      return context.schema:getType(typeNode.name.value)
    end

    -- True when a value of subType may be used where superType is expected.
    local function isTypeSubTypeOf(subType, superType)
      if subType == superType then return true end
      if superType.__type == 'NonNull' then
        if subType.__type == 'NonNull' then
          return isTypeSubTypeOf(subType.ofType, superType.ofType)
        end
        return false
      elseif subType.__type == 'NonNull' then
        return isTypeSubTypeOf(subType.ofType, superType)
      end
      if superType.__type == 'List' then
        if subType.__type == 'List' then
          return isTypeSubTypeOf(subType.ofType, superType.ofType)
        end
        return false
      elseif subType.__type == 'List' then
        return false
      end
      if subType.__type ~= 'Object' then return false end
      if superType.__type == 'Interface' then
        local implementors = context.schema:getImplementors(superType.name)
        return implementors and implementors[context.schema:getType(subType.name)]
      elseif superType.__type == 'Union' then
        local members = superType.types
        for i = 1, #members do
          if members[i] == subType then
            return true
          end
        end
        return false
      end
      return false
    end

    for fieldName, fieldArguments in pairs(arguments) do
      local parentField = getParentField(context, fieldName)
      if parentField then
        local declared = parentField.arguments or {}
        for i = 1, #fieldArguments do
          local argument = fieldArguments[i]
          if argument.value.kind == 'variable' then
            local declaration = declared[argument.name.value]
            local variableDefinition = variableMap[argument.value.name.value]
            if declaration and variableDefinition then
              -- Declarations may be a bare type or a table holding it in `kind`.
              local argumentType = declaration.kind or declaration
              local variableType = typeFromAST(variableDefinition.type)
              if variableType then
                if variableDefinition.defaultValue ~= nil and variableType.__type ~= 'NonNull' then
                  variableType = types.nonNull(variableType)
                end
                if not isTypeSubTypeOf(variableType, argumentType) then
                  error('Variable type mismatch')
                end
              end
            end
          end
        end
      end
    end
  end
  451. return rules