rules.lua 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. local types = require 'types'
  2. local util = require 'util'
  3. local rules = {}
  --- Ensures that no two named operations in a document share a name.
  -- Anonymous operations are ignored by this rule.
  -- @param node operation definition AST node
  -- @param context validation context; `context.operationNames` is the set of
  --   operation names recorded so far and is updated by this rule
  function rules.uniqueOperationNames(node, context)
    if not node.name then
      return
    end
    local operationName = node.name.value
    if not operationName then
      return
    end
    local recorded = context.operationNames
    if recorded[operationName] then
      error('Multiple operations exist named "' .. operationName .. '"')
    end
    recorded[operationName] = true
  end
  --- Forbids an anonymous operation from appearing alongside any other operation.
  -- It errors when an anonymous operation was already seen. It also errors when
  -- this operation is anonymous and named operations have already been recorded
  -- in `context.operationNames` (populated by uniqueOperationNames).
  -- @param node operation definition AST node
  -- @param context validation context; sets `context.hasAnonymousOperation`
  function rules.loneAnonymousOperation(node, context)
    local anonymous = not (node.name and node.name.value)
    local conflicting = context.hasAnonymousOperation
      or (anonymous and next(context.operationNames))
    if conflicting then
      error('Cannot have more than one operation when using anonymous operations')
    end
    if anonymous then
      context.hasAnonymousOperation = true
    end
  end
  --- Rejects a field selection that its parent type does not define.
  -- The top of `context.objects` is expected to be `false` when the current
  -- field could not be resolved. The entry beneath it is the parent type, whose
  -- `name` is reported. (The traversal that fills context.objects is not shown
  -- here; this is assumed from the checks below.)
  -- @param node field AST node
  -- @param context validation context
  function rules.fieldsDefinedOnType(node, context)
    local stack = context.objects
    local depth = #stack
    if stack[depth] ~= false then
      return
    end
    local parent = stack[depth - 1]
    error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
  end
  --- Rejects arguments that the selected field's schema definition does not declare.
  -- The field definition is looked up on the parent type, which is the
  -- second-from-top entry of `context.objects`.
  -- @param node field AST node
  -- @param context validation context
  function rules.argumentsDefinedOnType(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end
    local objects = context.objects
    local fieldDefinition = objects[#objects - 1].fields[node.name.value]
    for _, arg in pairs(supplied) do
      local argName = arg.name.value
      local declared = fieldDefinition.arguments[argName]
      if not declared then
        error('Non-existent argument "' .. argName .. '"')
      end
    end
  end
  --- Forbids a sub-selection on a field whose type (the top of
  -- `context.objects`) is a Scalar.
  -- @param node field AST node
  -- @param context validation context
  function rules.scalarFieldsAreLeaves(node, context)
    local objects = context.objects
    if objects[#objects].__type ~= 'Scalar' then
      return
    end
    if node.selectionSet then
      error('Scalar values cannot have subselections')
    end
  end
  --- Requires a sub-selection on any field whose type (the top of
  -- `context.objects`) is an Object, Interface or Union.
  -- @param node field AST node
  -- @param context validation context
  function rules.compositeFieldsAreNotLeaves(node, context)
    local objects = context.objects
    local category = objects[#objects].__type
    if node.selectionSet then
      return
    end
    if category == 'Object' or category == 'Interface' or category == 'Union' then
      error('Composite types must have subselections')
    end
  end
  --- Checks that selections sharing a response key can be merged unambiguously.
  --
  -- Fields are gathered from the selection set itself and, recursively, from
  -- inline fragments and fragment spreads. Each fragment definition is visited
  -- at most once. Two fields with the same response key (alias or field name)
  -- conflict when:
  --   * their parent types may overlap and they select different fields
  --     ('Type name mismatch'),
  --   * their schema return types differ ('Return type mismatch'), or
  --   * their parent types may overlap and their arguments differ
  --     ('Argument mismatch').
  -- Two distinct Object parent types can never describe the same result object,
  -- so only their return types are compared.
  --
  -- @param node selection set AST node
  -- @param context validation context (uses schema, fragmentMap, objects)
  -- @raise error with one of the messages above on the first conflict found
  function rules.unambiguousSelections(node, context)
    local selectionMap = {}
    local seen = {}

    -- Structural equality of two argument value AST nodes. Variables compare by
    -- name, list literals element-wise, and input object literals field-wise
    -- (field order ignored). Any other node compares its `value`.
    local function valuesEqual(a, b)
      if a.kind ~= b.kind then
        return false
      end
      if a.kind == 'variable' then
        return a.name.value == b.name.value
      elseif a.kind == 'list' and a.values and b.values then
        if #a.values ~= #b.values then
          return false
        end
        for i = 1, #a.values do
          if not valuesEqual(a.values[i], b.values[i]) then
            return false
          end
        end
        return true
      elseif a.kind == 'inputObject' and a.values and b.values then
        if #a.values ~= #b.values then
          return false
        end
        local fieldsA = {}
        for _, field in ipairs(a.values) do
          fieldsA[field.name] = field.value
        end
        for _, field in ipairs(b.values) do
          local other = fieldsA[field.name]
          if other == nil or not valuesEqual(other, field.value) then
            return false
          end
        end
        return true
      end
      return a.value == b.value
    end

    -- Returns a conflict message for two entries sharing a response key, or nil.
    local function findConflict(entryA, entryB)
      local parentA, parentB = entryA.parent, entryB.parent
      -- Parent types can't overlap if they're different objects; interface and
      -- union types may overlap and get the full check.
      local disjointParents = parentA ~= parentB
        and parentA.__type == 'Object' and parentB.__type == 'Object'

      -- Aliases that map two different fields to the same name.
      if not disjointParents and entryA.field.name.value ~= entryB.field.name.value then
        return 'Type name mismatch'
      end
      -- Fields with the same response key must have the same return type.
      if entryA.definition and entryB.definition and entryA.definition ~= entryB.definition then
        return 'Return type mismatch'
      end
      if disjointParents then
        return nil
      end
      -- Arguments must be identical for two fields with the same name.
      local argsA = entryA.field.arguments or {}
      local argsB = entryB.field.arguments or {}
      if #argsA ~= #argsB then
        return 'Argument mismatch'
      end
      local valueByName = {}
      for _, argument in ipairs(argsA) do
        valueByName[argument.name.value] = argument.value
      end
      for _, argument in ipairs(argsB) do
        local other = valueByName[argument.name.value]
        if not other or not valuesEqual(other, argument.value) then
          return 'Argument mismatch'
        end
      end
      return nil
    end

    -- Records `entry` under `key` after checking it against earlier entries.
    local function validateField(key, entry)
      local entries = selectionMap[key]
      if not entries then
        selectionMap[key] = { entry }
        return
      end
      for i = 1, #entries do
        local conflict = findConflict(entries[i], entry)
        if conflict then
          error(conflict)
        end
      end
      entries[#entries + 1] = entry
    end

    -- Walks a selection set and its fragments, checking each field selection.
    local function validateSelectionSet(selectionSet, parentType)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'field' then
          local fieldName = selection.name.value
          local fieldDefinition = parentType and parentType.fields and parentType.fields[fieldName]
          -- Fields not found on parentType are left to other rules. Skip only
          -- this selection so that the remaining ones are still checked.
          if fieldDefinition then
            local key = selection.alias and selection.alias.name.value or fieldName
            validateField(key, {
              parent = parentType,
              field = selection,
              definition = fieldDefinition.kind
            })
          end
        elseif selection.kind == 'inlineFragment' then
          local fragmentType = selection.typeCondition
            and context.schema:getType(selection.typeCondition.name.value)
            or parentType
          validateSelectionSet(selection.selectionSet, fragmentType)
        elseif selection.kind == 'fragmentSpread' then
          local fragmentDefinition = context.fragmentMap[selection.name.value]
          -- Undefined fragments are reported by fragmentSpreadTargetDefined.
          -- Checking first also avoids using nil as a table key.
          if fragmentDefinition and not seen[fragmentDefinition] then
            seen[fragmentDefinition] = true
            if fragmentDefinition.typeCondition then
              local fragmentType = context.schema:getType(fragmentDefinition.typeCondition.name.value)
              validateSelectionSet(fragmentDefinition.selectionSet, fragmentType)
            end
          end
        end
      end
    end

    validateSelectionSet(node, context.objects[#context.objects])
  end
  --- Rejects a node that passes the same argument name more than once.
  -- @param node AST node with an optional `arguments` list
  -- @param context validation context (unused)
  function rules.uniqueArgumentNames(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end
    local occurrences = {}
    for _, arg in ipairs(supplied) do
      local argName = arg.name.value
      occurrences[argName] = (occurrences[argName] or 0) + 1
      if occurrences[argName] > 1 then
        error('Encountered multiple arguments named "' .. argName .. '"')
      end
    end
  end
  --- Coerces every supplied argument value against its declared argument type.
  -- Any type error is raised by `util.coerceValue`.
  -- @param node field AST node
  -- @param context validation context
  function rules.argumentsOfCorrectType(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end
    local objects = context.objects
    local fieldDefinition = objects[#objects - 1].fields[node.name.value]
    for _, arg in pairs(supplied) do
      local expectedType = fieldDefinition.arguments[arg.name.value]
      util.coerceValue(arg.value, expectedType)
    end
  end
  --- Ensures that every NonNull argument declared on the field is supplied.
  -- @param node field AST node
  -- @param context validation context
  function rules.requiredArgumentsPresent(node, context)
    local supplied = node.arguments or {}
    local objects = context.objects
    local fieldDefinition = objects[#objects - 1].fields[node.name.value]

    local function isSupplied(argName)
      return util.find(supplied, function(arg)
        return arg.name.value == argName
      end)
    end

    for argName, argType in pairs(fieldDefinition.arguments) do
      if argType.__type == 'NonNull' and not isSupplied(argName) then
        error('Required argument "' .. argName .. '" was not supplied.')
      end
    end
  end
  --- Rejects a document that defines two fragments with the same name.
  -- @param node document AST node
  -- @param context validation context (unused)
  function rules.uniqueFragmentNames(node, context)
    local counts = {}
    for _, definition in ipairs(node.definitions) do
      if definition.kind == 'fragmentDefinition' then
        local fragmentName = definition.name.value
        counts[fragmentName] = (counts[fragmentName] or 0) + 1
        if counts[fragmentName] > 1 then
          error('Encountered multiple fragments named "' .. fragmentName .. '"')
        end
      end
    end
  end
  --- Checks that a fragment's type condition names an existing composite type
  -- (Object, Interface or Union). Fragments without a type condition pass.
  -- @param node fragment AST node
  -- @param context validation context (uses schema)
  function rules.fragmentHasValidType(node, context)
    local condition = node.typeCondition
    if not condition then
      return
    end
    local typeName = condition.name.value
    local schemaType = context.schema:getType(typeName)
    if not schemaType then
      error('Fragment refers to non-existent type "' .. typeName .. '"')
    end
    local category = schemaType.__type
    local isComposite = category == 'Object' or category == 'Interface' or category == 'Union'
    if not isComposite then
      error('Fragment type must be an Object, Interface, or Union, got ' .. schemaType.__type)
    end
  end
  --- Rejects fragment definitions whose names do not appear in
  -- `context.usedFragments`.
  -- @param node document AST node
  -- @param context validation context
  function rules.noUnusedFragments(node, context)
    for _, definition in ipairs(node.definitions) do
      local isFragment = definition.kind == 'fragmentDefinition'
      if isFragment and not context.usedFragments[definition.name.value] then
        error('Fragment "' .. definition.name.value .. '" was not used.')
      end
    end
  end
  --- Rejects a fragment spread that names a fragment missing from
  -- `context.fragmentMap`.
  -- @param node fragment spread AST node
  -- @param context validation context
  function rules.fragmentSpreadTargetDefined(node, context)
    local fragmentName = node.name.value
    if context.fragmentMap[fragmentName] then
      return
    end
    error('Fragment spread refers to non-existent fragment "' .. fragmentName .. '"')
  end
  --- Rejects a fragment definition whose spreads, followed transitively, lead
  -- back to a fragment already on the current spread chain.
  --
  -- Spreads are searched everywhere in the selection tree: directly, inside
  -- inline fragments and inside field sub-selections. Only the current chain
  -- of spreads counts as a cycle. Spreading the same fragment from two separate
  -- branches (a diamond) is allowed.
  --
  -- @param node fragment definition AST node
  -- @param context validation context (uses fragmentMap)
  function rules.fragmentDefinitionHasNoCycles(node, context)
    -- Fragments on the current spread chain, starting with this definition.
    local onPath = { [node.name.value] = true }
    -- Fragments already fully explored without finding a cycle.
    local finished = {}

    local function detectCycles(selectionSet)
      if not selectionSet then
        return
      end
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'field' or selection.kind == 'inlineFragment' then
          detectCycles(selection.selectionSet)
        elseif selection.kind == 'fragmentSpread' then
          local spreadName = selection.name.value
          if onPath[spreadName] then
            error('Fragment definition has cycles')
          end
          if not finished[spreadName] then
            local fragmentDefinition = context.fragmentMap[spreadName]
            if fragmentDefinition and fragmentDefinition.typeCondition then
              onPath[spreadName] = true
              detectCycles(fragmentDefinition.selectionSet)
              onPath[spreadName] = nil
            end
            finished[spreadName] = true
          end
        end
      end
    end

    detectCycles(node.selectionSet)
  end
  --- Checks that a fragment can apply to its enclosing type. The sets of
  -- possible object types of the enclosing type and of the fragment's type
  -- condition must overlap.
  -- Handles both inline fragments and named fragment spreads. Missing
  -- fragments, unknown types and non-composite types are left to other rules.
  -- @param node inlineFragment or fragmentSpread AST node
  -- @param context validation context (uses objects, schema, fragmentMap)
  function rules.fragmentSpreadIsPossible(node, context)
    local parentType = context.objects[#context.objects - 1]
    local fragmentType
    if node.kind == 'inlineFragment' then
      -- An inline fragment without a type condition uses the enclosing type.
      fragmentType = node.typeCondition
        and context.schema:getType(node.typeCondition.name.value)
        or parentType
    else
      local fragment = context.fragmentMap[node.name.value]
      -- Undefined fragments are reported by fragmentSpreadTargetDefined.
      if not fragment or not fragment.typeCondition then
        return
      end
      fragmentType = context.schema:getType(fragment.typeCondition.name.value)
    end
    -- Some types are not present in the schema. Let other rules handle this.
    if not parentType or not fragmentType then
      return
    end

    -- Set of object types a composite type may resolve to, keyed by type.
    -- Returns nil for non-composite types.
    local function possibleTypes(kind)
      if kind.__type == 'Object' then
        return { [kind] = kind }
      elseif kind.__type == 'Interface' then
        return context.schema:getImplementors(kind.name)
      elseif kind.__type == 'Union' then
        local members = {}
        for _, member in ipairs(kind.types) do
          members[member] = member
        end
        return members
      end
      return nil
    end

    local parentTypes = possibleTypes(parentType)
    local fragmentTypes = possibleTypes(fragmentType)
    -- Non-composite types are reported by other rules (e.g. fragmentHasValidType)
    -- instead of passing nil to util.find.
    if not parentTypes or not fragmentTypes then
      return
    end
    local valid = util.find(parentTypes, function(kind)
      return fragmentTypes[kind]
    end)
    if not valid then
      error('Fragment type condition is not possible for given type')
    end
  end
  --- Rejects input object literals that specify the same field twice.
  -- The check applies at any depth of `node.value`, including inside list
  -- literals and nested input objects.
  -- @param node AST node carrying a `value` (presumably an argument)
  -- @param context validation context (unused)
  function rules.uniqueInputObjectFields(node, context)
    local function validateValue(value)
      if not value then
        return
      end
      if value.kind == 'listType' or value.kind == 'nonNullType' then
        -- Type wrapper nodes: inspect the wrapped node.
        return validateValue(value.type)
      elseif value.kind == 'list' then
        -- List literals may contain input objects, e.g. [{a: 1, a: 2}].
        for _, item in ipairs(value.values or {}) do
          validateValue(item)
        end
      elseif value.kind == 'inputObject' then
        local fieldMap = {}
        for _, field in ipairs(value.values) do
          if fieldMap[field.name] then
            error('Multiple input object fields named "' .. field.name .. '"')
          end
          fieldMap[field.name] = true
          validateValue(field.value)
        end
      end
    end

    validateValue(node.value)
  end
  --- Rejects directives that the schema does not know.
  -- @param node AST node with an optional `directives` list
  -- @param context validation context (uses schema:getDirective)
  function rules.directivesAreDefined(node, context)
    local directives = node.directives
    if not directives then
      return
    end
    for _, directive in pairs(directives) do
      local directiveName = directive.name.value
      local known = context.schema:getDirective(directiveName)
      if not known then
        error('Unknown directive "' .. directiveName .. '"')
      end
    end
  end
  --- Checks that every declared variable type names an existing input type
  -- (Scalar, Enum or InputObject) once list and non-null wrappers are removed.
  -- @param node operation definition AST node
  -- @param context validation context (uses schema)
  function rules.variablesHaveCorrectType(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end
    for _, definition in ipairs(definitions) do
      -- Strip [T] and T! wrappers down to the innermost type node.
      local typeNode = definition.type
      while typeNode.kind == 'listType' or typeNode.kind == 'nonNullType' do
        typeNode = typeNode.type
      end
      if typeNode.kind == 'namedType' then
        local typeName = typeNode.name.value
        local schemaType = context.schema:getType(typeName)
        if not schemaType then
          error('Variable specifies unknown type "' .. tostring(typeName) .. '"')
        end
        local category = schemaType.__type
        if category ~= 'Scalar' and category ~= 'Enum' and category ~= 'InputObject' then
          error('Variable types must be scalars, enums, or input objects, got "' .. schemaType.__type .. '"')
        end
      end
    end
  end
  --- Validates variable default values.
  -- A non-null variable may not have a default. Any other default is coerced
  -- against the variable's declared type, including list types such as
  -- `[ID] = ["1"]`. Coercion errors are raised by `util.coerceValue`.
  -- @param node operation definition AST node
  -- @param context validation context (uses schema)
  function rules.variableDefaultValuesHaveCorrectType(node, context)
    if not node.variableDefinitions then
      return
    end

    -- Builds the schema type described by a variable type AST.
    -- Returns nil when a named type is unknown to the schema.
    local function resolveType(typeNode)
      if typeNode.kind == 'listType' then
        local innerType = resolveType(typeNode.type)
        return innerType and types.list(innerType)
      elseif typeNode.kind == 'nonNullType' then
        local innerType = resolveType(typeNode.type)
        return innerType and types.nonNull(innerType)
      end
      return context.schema:getType(typeNode.name.value)
    end

    for _, definition in ipairs(node.variableDefinitions) do
      if definition.defaultValue then
        if definition.type.kind == 'nonNullType' then
          error('Non-null variables can not have default values')
        end
        local schemaType = resolveType(definition.type)
        -- Unknown types are reported by variablesHaveCorrectType.
        if schemaType then
          util.coerceValue(definition.defaultValue, schemaType)
        end
      end
    end
  end
  --- Rejects declared variables that never appear in
  -- `context.variableReferences`.
  -- @param node operation definition AST node
  -- @param context validation context
  function rules.variablesAreUsed(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end
    for _, definition in ipairs(definitions) do
      local variableName = definition.variable.name.value
      local referenced = context.variableReferences[variableName]
      if not referenced then
        error('Unused variable "' .. variableName .. '"')
      end
    end
  end
  --- Rejects references to variables that the operation does not declare.
  -- Does nothing when `context.variableReferences` is absent.
  -- @param node operation definition AST node
  -- @param context validation context
  function rules.variablesAreDefined(node, context)
    local references = context.variableReferences
    if not references then
      return
    end
    local declared = {}
    for _, definition in ipairs(node.variableDefinitions or {}) do
      declared[definition.variable.name.value] = true
    end
    for variableName in pairs(references) do
      if not declared[variableName] then
        error('Unknown variable "' .. variableName .. '"')
      end
    end
  end
  --- Builds the schema type described by a variable type AST node.
  -- Returns nil when a named type is not present in the schema.
  local function variableTypeFromAST(schema, typeNode)
    if typeNode.kind == 'listType' then
      local innerType = variableTypeFromAST(schema, typeNode.type)
      return innerType and types.list(innerType)
    elseif typeNode.kind == 'nonNullType' then
      local innerType = variableTypeFromAST(schema, typeNode.type)
      return innerType and types.nonNull(innerType)
    end
    assert(typeNode.kind == 'namedType', 'Variable must be a named type')
    return schema:getType(typeNode.name.value)
  end

  --- Returns true when a value of `subType` may be used where `superType` is
  -- expected (non-null/list wrappers, interface implementors, union members).
  local function isVariableSubTypeOf(schema, subType, superType)
    if subType == superType then
      return true
    end
    if superType.__type == 'NonNull' then
      if subType.__type == 'NonNull' then
        return isVariableSubTypeOf(schema, subType.ofType, superType.ofType)
      end
      return false
    elseif subType.__type == 'NonNull' then
      -- A non-null value is acceptable where a nullable one is expected.
      return isVariableSubTypeOf(schema, subType.ofType, superType)
    end
    if superType.__type == 'List' then
      if subType.__type == 'List' then
        return isVariableSubTypeOf(schema, subType.ofType, superType.ofType)
      end
      return false
    elseif subType.__type == 'List' then
      return false
    end
    if subType.__type ~= 'Object' then
      return false
    end
    if superType.__type == 'Interface' then
      local implementors = schema:getImplementors(superType.name)
      return implementors ~= nil and implementors[schema:getType(subType.name)] ~= nil
    elseif superType.__type == 'Union' then
      for _, member in ipairs(superType.types) do
        if member == subType then
          return true
        end
      end
      return false
    end
    return false
  end

  --- Collects the arguments of every field reachable from a fragment's
  -- selection set, following inline fragments and nested fragment spreads.
  -- Returns a map of field name to a list of argument AST nodes. Each
  -- selection node is visited at most once, which also stops cyclic spreads.
  local function collectFragmentArguments(context, fragment)
    local arguments = {}
    local seen = {}
    local function collect(referencedNode)
      if referencedNode.kind == 'selectionSet' then
        for _, selection in ipairs(referencedNode.selections) do
          if not seen[selection] then
            seen[selection] = true
            collect(selection)
          end
        end
      elseif referencedNode.kind == 'field' and referencedNode.arguments then
        local fieldName = referencedNode.name.value
        local bucket = arguments[fieldName] or {}
        arguments[fieldName] = bucket
        for _, argument in ipairs(referencedNode.arguments) do
          bucket[#bucket + 1] = argument
        end
      elseif referencedNode.kind == 'inlineFragment' then
        collect(referencedNode.selectionSet)
      elseif referencedNode.kind == 'fragmentSpread' then
        local spreadFragment = context.fragmentMap[referencedNode.name.value]
        if spreadFragment then
          collect(spreadFragment.selectionSet)
        end
      end
    end
    collect(fragment.selectionSet)
    return arguments
  end

  --- Checks that every variable passed as an argument has a type compatible
  -- with the argument's declared type.
  -- A variable with a default value is treated as non-null. Unknown fields,
  -- arguments, variables and types are skipped here; other rules report them.
  -- @param node field or fragmentSpread AST node
  -- @param context validation context (uses currentOperation, objects, schema,
  --   fragmentMap)
  -- @raise 'Variable type mismatch' on an incompatible usage
  function rules.variableUsageAllowed(node, context)
    local operation = context.currentOperation
    if not operation then
      return
    end

    local variableMap = {}
    for _, definition in ipairs(operation.variableDefinitions or {}) do
      variableMap[definition.variable.name.value] = definition
    end

    local arguments
    if node.kind == 'field' then
      arguments = { [node.name.value] = node.arguments }
    elseif node.kind == 'fragmentSpread' then
      local fragment = context.fragmentMap[node.name.value]
      if fragment then
        arguments = collectFragmentArguments(context, fragment)
      end
    end
    if not arguments then
      return
    end

    local schema = context.schema
    local parentType = context.objects[#context.objects - 1]
    for fieldName, fieldArguments in pairs(arguments) do
      local parentField = parentType and parentType.fields and parentType.fields[fieldName]
      if parentField then
        for i = 1, #fieldArguments do
          local argument = fieldArguments[i]
          if argument.value.kind == 'variable' then
            local argumentType = parentField.arguments[argument.name.value]
            local variableDefinition = variableMap[argument.value.name.value]
            -- Unknown arguments / undefined variables are reported by
            -- argumentsDefinedOnType / variablesAreDefined.
            if argumentType and variableDefinition then
              local variableType = variableTypeFromAST(schema, variableDefinition.type)
              if variableType then
                if variableDefinition.defaultValue ~= nil and variableType.__type ~= 'NonNull' then
                  variableType = types.nonNull(variableType)
                end
                if not isVariableSubTypeOf(schema, variableType, argumentType) then
                  error('Variable type mismatch')
                end
              end
            end
          end
        end
      end
    end
  end
  443. return rules