rules.lua 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. local types = require 'types'
  2. local rules = {}
  3. local function checkValueMatchesType(valueNode, schemaType)
  4. if schemaType.__type == 'NonNull' then
  5. return checkValueMatchesType(schemaType.ofType, valueNode)
  6. end
  7. if schemaType.__type == 'List' then
  8. if valueNode.kind ~= 'list' then
  9. error('Expected a list')
  10. end
  11. for i = 1, #valueNode.values do
  12. checkValueMatchesType(schemaType.ofType, valueNode.values[i])
  13. end
  14. end
  15. if schemaType.__type == 'InputObject' then
  16. if valueNode.kind ~= 'inputObject' then
  17. error('Expected an input object')
  18. end
  19. for _, field in ipairs(valueNode.values) do
  20. if not schemaType.fields[field.name] then
  21. error('Unknown input object field "' .. field.name .. '"')
  22. end
  23. checkValueMatchesType(schemaType.fields[field.name].kind, field.value)
  24. end
  25. end
  26. if schemaType.__type == 'Enum' then
  27. if valueNode.kind ~= 'enum' then
  28. error('Expected enum value, got ' .. valueNode.kind)
  29. end
  30. if not schemaType.values[valueNode.value] then
  31. error('Invalid enum value "' .. valueNode.value .. '"')
  32. end
  33. end
  34. if schemaType.__type == 'Scalar' then
  35. if schemaType.parseLiteral(valueNode) == nil then
  36. error('Could not coerce "' .. valueNode.value .. '" to "' .. schemaType.name .. '"')
  37. end
  38. end
  39. end
  40. function rules.uniqueOperationNames(node, context)
  41. local name = node.name and node.name.value
  42. if name then
  43. if context.operationNames[name] then
  44. error('Multiple operations exist named "' .. name .. '"')
  45. end
  46. context.operationNames[name] = true
  47. end
  48. end
  49. function rules.loneAnonymousOperation(node, context)
  50. local name = node.name and node.name.value
  51. if context.hasAnonymousOperation or (not name and next(context.operationNames)) then
  52. error('Cannot have more than one operation when using anonymous operations')
  53. end
  54. if not name then
  55. context.hasAnonymousOperation = true
  56. end
  57. end
  58. function rules.fieldsDefinedOnType(node, context)
  59. if context.objects[#context.objects] == false then
  60. local parent = context.objects[#context.objects - 1]
  61. error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
  62. end
  63. end
  64. function rules.argumentsDefinedOnType(node, context)
  65. if node.arguments then
  66. local parentField = context.objects[#context.objects - 1].fields[node.name.value]
  67. for _, argument in pairs(node.arguments) do
  68. local name = argument.name.value
  69. if not parentField.arguments[name] then
  70. error('Non-existent argument "' .. name .. '"')
  71. end
  72. end
  73. end
  74. end
  75. function rules.scalarFieldsAreLeaves(node, context)
  76. if context.objects[#context.objects].__type == 'Scalar' and node.selectionSet then
  77. error('Scalar values cannot have subselections')
  78. end
  79. end
  80. function rules.compositeFieldsAreNotLeaves(node, context)
  81. local _type = context.objects[#context.objects].__type
  82. local isCompositeType = _type == 'Object' or _type == 'Interface' or _type == 'Union'
  83. if isCompositeType and not node.selectionSet then
  84. error('Composite types must have subselections')
  85. end
  86. end
  87. function rules.unambiguousSelections(node, context)
  88. local selectionMap = {}
  89. local seen = {}
  90. local function findConflict(entryA, entryB)
  91. -- Parent types can't overlap if they're different objects.
  92. -- Interface and union types may overlap.
  93. if entryA.parent ~= entryB.parent and entryA.__type == 'Object' and entryB.__type == 'Object' then
  94. return
  95. end
  96. -- Error if there are aliases that map two different fields to the same name.
  97. if entryA.field.name.value ~= entryB.field.name.value then
  98. return 'Type name mismatch'
  99. end
  100. -- Error if there are fields with the same name that have different return types.
  101. if entryA.definition and entryB.definition and entryA.definition ~= entryB.definition then
  102. return 'Return type mismatch'
  103. end
  104. -- Error if arguments are not identical for two fields with the same name.
  105. local argsA = entryA.field.arguments or {}
  106. local argsB = entryB.field.arguments or {}
  107. if #argsA ~= #argsB then
  108. return 'Argument mismatch'
  109. end
  110. local argMap = {}
  111. for i = 1, #argsA do
  112. argMap[argsA[i].name.value] = argsA[i].value
  113. end
  114. for i = 1, #argsB do
  115. local name = argsB[i].name.value
  116. if not argMap[name] then
  117. return 'Argument mismatch'
  118. elseif argMap[name].kind ~= argsB[i].value.kind then
  119. return 'Argument mismatch'
  120. elseif argMap[name].value ~= argsB[i].value.value then
  121. return 'Argument mismatch'
  122. end
  123. end
  124. end
  125. local function validateField(key, entry)
  126. if selectionMap[key] then
  127. for i = 1, #selectionMap[key] do
  128. local conflict = findConflict(selectionMap[key][i], entry)
  129. if conflict then
  130. error(conflict)
  131. end
  132. end
  133. table.insert(selectionMap[key], entry)
  134. else
  135. selectionMap[key] = { entry }
  136. end
  137. end
  138. -- Recursively make sure that there are no ambiguous selections with the same name.
  139. local function validateSelectionSet(selectionSet, parentType)
  140. for _, selection in ipairs(selectionSet.selections) do
  141. if selection.kind == 'field' then
  142. if not parentType or not parentType.fields[selection.name.value] then return end
  143. local key = selection.alias and selection.alias.name.value or selection.name.value
  144. local definition = parentType.fields[selection.name.value].kind
  145. local fieldEntry = {
  146. parent = parentType,
  147. field = selection,
  148. definition = definition
  149. }
  150. if seen[definition] then
  151. return
  152. end
  153. seen[definition] = true
  154. validateField(key, fieldEntry)
  155. elseif selection.kind == 'inlineFragment' then
  156. local parentType = selection.typeCondition and context.schema:getType(selection.typeCondition.name.value) or parentType
  157. validateSelectionSet(selection.selectionSet, parentType)
  158. elseif selection.kind == 'fragmentSpread' then
  159. local fragmentDefinition = context.fragmentMap[selection.name.value]
  160. if fragmentDefinition and fragmentDefinition.typeCondition then
  161. local parentType = context.schema:getType(fragmentDefinition.typeCondition.name.value)
  162. validateSelectionSet(fragmentDefinition.selectionSet, parentType)
  163. end
  164. end
  165. end
  166. end
  167. validateSelectionSet(node, context.objects[#context.objects])
  168. end
  169. function rules.uniqueArgumentNames(node, context)
  170. if node.arguments then
  171. local arguments = {}
  172. for _, argument in ipairs(node.arguments) do
  173. local name = argument.name.value
  174. if arguments[name] then
  175. error('Encountered multiple arguments named "' .. name .. '"')
  176. end
  177. arguments[name] = true
  178. end
  179. end
  180. end
  181. function rules.argumentsOfCorrectType(node, context)
  182. if node.arguments then
  183. local parentField = context.objects[#context.objects - 1].fields[node.name.value]
  184. for _, argument in pairs(node.arguments) do
  185. local name = argument.name.value
  186. local argumentType = parentField.arguments[name]
  187. checkValueMatchesType(argument.value, argumentType)
  188. end
  189. end
  190. end
  191. function rules.requiredArgumentsPresent(node, context)
  192. local arguments = node.arguments or {}
  193. local parentField = context.objects[#context.objects - 1].fields[node.name.value]
  194. for name, argument in pairs(parentField.arguments) do
  195. if argument.__type == 'NonNull' then
  196. local present = false
  197. for i = 1, #arguments do
  198. if arguments[i].name.value == name then
  199. present = true
  200. break
  201. end
  202. end
  203. if not present then
  204. error('Required argument "' .. name .. '" was not supplied.')
  205. end
  206. end
  207. end
  208. end
  209. function rules.uniqueFragmentNames(node, context)
  210. local fragments = {}
  211. for _, definition in ipairs(node.definitions) do
  212. if definition.kind == 'fragmentDefinition' then
  213. local name = definition.name.value
  214. if fragments[name] then
  215. error('Encountered multiple fragments named "' .. name .. '"')
  216. end
  217. fragments[name] = true
  218. end
  219. end
  220. end
  221. function rules.fragmentHasValidType(node, context)
  222. if not node.typeCondition then return end
  223. local name = node.typeCondition.name.value
  224. local kind = context.schema:getType(name)
  225. if not kind then
  226. error('Fragment refers to non-existent type "' .. name .. '"')
  227. end
  228. if kind.__type ~= 'Object' and kind.__type ~= 'Interface' and kind.__type ~= 'Union' then
  229. error('Fragment type must be an Object, Interface, or Union, got ' .. kind.__type)
  230. end
  231. end
  232. function rules.noUnusedFragments(node, context)
  233. for _, definition in ipairs(node.definitions) do
  234. if definition.kind == 'fragmentDefinition' then
  235. local name = definition.name.value
  236. if not context.usedFragments[name] then
  237. error('Fragment "' .. name .. '" was not used.')
  238. end
  239. end
  240. end
  241. end
  242. function rules.fragmentSpreadTargetDefined(node, context)
  243. if not context.fragmentMap[node.name.value] then
  244. error('Fragment spread refers to non-existent fragment "' .. node.name.value .. '"')
  245. end
  246. end
  247. function rules.fragmentDefinitionHasNoCycles(node, context)
  248. local seen = { [node.name.value] = true }
  249. local function detectCycles(selectionSet)
  250. for _, selection in ipairs(selectionSet.selections) do
  251. if selection.kind == 'inlineFragment' then
  252. detectCycles(selection.selectionSet)
  253. elseif selection.kind == 'fragmentSpread' then
  254. if seen[selection.name.value] then
  255. error('Fragment definition has cycles')
  256. end
  257. seen[selection.name.value] = true
  258. local fragmentDefinition = context.fragmentMap[selection.name.value]
  259. if fragmentDefinition and fragmentDefinition.typeCondition then
  260. detectCycles(fragmentDefinition.selectionSet)
  261. end
  262. end
  263. end
  264. end
  265. detectCycles(node.selectionSet)
  266. end
  267. function rules.fragmentSpreadIsPossible(node, context)
  268. local fragment = node.kind == 'inlineFragment' and node or context.fragmentMap[node.name.value]
  269. local parentType = context.objects[#context.objects - 1]
  270. local fragmentType
  271. if node.kind == 'inlineFragment' then
  272. fragmentType = node.typeCondition and context.schema:getType(node.typeCondition.name.value) or parentType
  273. else
  274. fragmentType = context.schema:getType(fragment.typeCondition.name.value)
  275. end
  276. -- Some types are not present in the schema. Let other rules handle this.
  277. if not parentType or not fragmentType then return end
  278. local function getTypes(kind)
  279. if kind.__type == 'Object' then
  280. return { [kind] = kind }
  281. elseif kind.__type == 'Interface' then
  282. return context.schema:getImplementors(kind.name)
  283. elseif kind.__type == 'Union' then
  284. local types = {}
  285. for i = 1, #kind.types do
  286. types[kind.types[i]] = kind.types[i]
  287. end
  288. return types
  289. end
  290. end
  291. local parentTypes = getTypes(parentType)
  292. local fragmentTypes = getTypes(fragmentType)
  293. local valid = false
  294. for _, kind in pairs(parentTypes) do
  295. if fragmentTypes[kind] then
  296. valid = true
  297. break
  298. end
  299. end
  300. if not valid then
  301. error('Fragment type condition is not possible for given type')
  302. end
  303. end
  304. function rules.uniqueInputObjectFields(node, context)
  305. local function validateValue(value)
  306. if value.kind == 'listType' or value.kind == 'nonNullType' then
  307. return validateValue(value.type)
  308. elseif value.kind == 'inputObject' then
  309. local fieldMap = {}
  310. for _, field in ipairs(value.values) do
  311. if fieldMap[field.name] then
  312. error('Multiple input object fields named "' .. field.name .. '"')
  313. end
  314. fieldMap[field.name] = true
  315. validateValue(field.value)
  316. end
  317. end
  318. end
  319. validateValue(node.value)
  320. end
  321. function rules.directivesAreDefined(node, context)
  322. if not node.directives then return end
  323. for _, directive in pairs(node.directives) do
  324. if not context.schema:getDirective(directive.name.value) then
  325. error('Unknown directive "' .. directive.name.value .. '"')
  326. end
  327. end
  328. end
  329. function rules.variablesHaveCorrectType(node, context)
  330. local function validateType(type)
  331. if type.kind == 'listType' or type.kind == 'nonNullType' then
  332. validateType(type.type)
  333. elseif type.kind == 'namedType' then
  334. local schemaType = context.schema:getType(type.name.value)
  335. if not schemaType then
  336. error('Variable specifies unknown type "' .. tostring(type.name.value) .. '"')
  337. elseif schemaType.__type ~= 'Scalar' and schemaType.__type ~= 'Enum' and schemaType.__type ~= 'InputObject' then
  338. error('Variable types must be scalars, enums, or input objects, got "' .. schemaType.__type .. '"')
  339. end
  340. end
  341. end
  342. if node.variableDefinitions then
  343. for _, definition in ipairs(node.variableDefinitions) do
  344. validateType(definition.type)
  345. end
  346. end
  347. end
  348. function rules.variableDefaultValuesHaveCorrectType(node, context)
  349. if node.variableDefinitions then
  350. for _, definition in ipairs(node.variableDefinitions) do
  351. if definition.type.kind == 'nonNullType' and definition.defaultValue then
  352. error('Non-null variables can not have default values')
  353. elseif definition.defaultValue then
  354. checkValueMatchesType(definition.defaultValue, context.schema:getType(definition.type.name.value))
  355. end
  356. end
  357. end
  358. end
  359. function rules.variablesAreUsed(node, context)
  360. if node.variableDefinitions then
  361. for _, definition in ipairs(node.variableDefinitions) do
  362. local variableName = definition.variable.name.value
  363. if not context.variableReferences[variableName] then
  364. error('Unused variable "' .. variableName .. '"')
  365. end
  366. end
  367. end
  368. end
  369. function rules.variablesAreDefined(node, context)
  370. if context.variableReferences then
  371. local variableMap = {}
  372. for _, definition in ipairs(node.variableDefinitions or {}) do
  373. variableMap[definition.variable.name.value] = true
  374. end
  375. for variable in pairs(context.variableReferences) do
  376. if not variableMap[variable] then
  377. error('Unknown variable "' .. variable .. '"')
  378. end
  379. end
  380. end
  381. end
  382. function rules.variableUsageAllowed(node, context)
  383. if context.currentOperation then
  384. local variableMap = {}
  385. for _, definition in ipairs(context.currentOperation.variableDefinitions or {}) do
  386. variableMap[definition.variable.name.value] = definition
  387. end
  388. local arguments
  389. if node.kind == 'field' then
  390. arguments = { [node.name.value] = node.arguments }
  391. elseif node.kind == 'fragmentSpread' then
  392. local function collectArguments(referencedNode)
  393. if referencedNode.kind == 'selectionSet' then
  394. for _, selection in ipairs(referencedNode.selections) do
  395. collectArguments(selection)
  396. end
  397. elseif referencedNode.kind == 'field' and referencedNode.arguments then
  398. local fieldName = referencedNode.name.value
  399. arguments[fieldName] = arguments[fieldName] or {}
  400. for _, argument in ipairs(referencedNode.arguments) do
  401. table.insert(arguments[fieldName], argument)
  402. end
  403. elseif referencedNode.kind == 'inlineFragment' then
  404. return collectArguments(referencedNode.selectionSet)
  405. elseif referencedNode.kind == 'fragmentSpread' then
  406. local fragment = context.fragmentMap[referencedNode.name.value]
  407. return fragment and collectArguments(fragment.selectionSet)
  408. end
  409. end
  410. local fragment = context.fragmentMap[node.name.value]
  411. if fragment then
  412. arguments = {}
  413. collectArguments(fragment.selectionSet)
  414. end
  415. end
  416. if not arguments then return end
  417. for field in pairs(arguments) do
  418. local parentField = context.objects[#context.objects - 1].fields[field]
  419. for i = 1, #arguments[field] do
  420. local argument = arguments[field][i]
  421. if argument.value.kind == 'variable' then
  422. local argumentType = parentField.arguments[argument.name.value]
  423. local variableName = argument.value.name.value
  424. local variableDefinition = variableMap[variableName]
  425. local hasDefault = variableDefinition.defaultValue ~= nil
  426. local function typeFromAST(variable)
  427. local innerType
  428. if variable.kind == 'listType' then
  429. innerType = typeFromAST(variable.type)
  430. return innerType and types.list(innerType)
  431. elseif variable.kind == 'nonNullType' then
  432. innerType = typeFromAST(variable.type)
  433. return innerType and types.nonNull(innerType)
  434. else
  435. assert(variable.kind == 'namedType', 'Variable must be a named type')
  436. return context.schema:getType(variable.name.value)
  437. end
  438. end
  439. local variableType = typeFromAST(variableDefinition.type)
  440. if hasDefault and variableType.__type ~= 'NonNull' then
  441. variableType = types.nonNull(variableType)
  442. end
  443. local function isTypeSubTypeOf(subType, superType)
  444. if subType == superType then return true end
  445. if superType.__type == 'NonNull' then
  446. if subType.__type == 'NonNull' then
  447. return isTypeSubTypeOf(subType.ofType, superType.ofType)
  448. end
  449. return false
  450. elseif subType.__type == 'NonNull' then
  451. return typeIsSubTypeOf(subType.ofType, superType)
  452. end
  453. if superType.__type == 'List' then
  454. if subType.__type == 'List' then
  455. return isTypeSubTypeOf(subType.ofType, superType.ofType)
  456. end
  457. return false
  458. elseif subType.__type == 'List' then
  459. return false
  460. end
  461. if subType.__type ~= 'Object' then return false end
  462. if superType.__type == 'Interface' then
  463. local implementors = context.schema:getImplementors(superType.name)
  464. return implementors and implementors[context.schema:getType(subType.name)]
  465. elseif superType.__type == 'Union' then
  466. local types = superType.types
  467. for i = 1, #types do
  468. if types[i] == subType then
  469. return true
  470. end
  471. end
  472. return false
  473. end
  474. return false
  475. end
  476. if not isTypeSubTypeOf(variableType, argumentType) then
  477. error('Variable type mismatch')
  478. end
  479. end
  480. end
  481. end
  482. end
  483. end
  484. return rules