rules.lua 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. local path = (...):gsub('%.[^%.]+$', '')
  2. local types = require(path .. '.types')
  3. local util = require(path .. '.util')
  4. local rules = {}
  --- Rejects a document in which two operations share the same name.
  -- Anonymous operations (no `name`) are ignored by this rule.
  -- Each name is recorded in `context.operationNames` as it is seen.
  -- @param node operation definition node
  -- @param context validation context (reads/writes `operationNames`)
  function rules.uniqueOperationNames(node, context)
    local operationName = node.name and node.name.value
    if not operationName then
      return
    end

    local known = context.operationNames
    if known[operationName] then
      error('Multiple operations exist named "' .. operationName .. '"')
    end
    known[operationName] = true
  end
  --- Enforces that an anonymous operation is the only operation in a document.
  -- Errors if an anonymous operation was already seen, or if this operation is
  -- anonymous while named operations are already recorded in
  -- `context.operationNames`. Marks `context.hasAnonymousOperation` when this
  -- operation is anonymous.
  -- @param node operation definition node
  -- @param context validation context
  function rules.loneAnonymousOperation(node, context)
    local operationName = node.name and node.name.value

    if context.hasAnonymousOperation then
      error('Cannot have more than one operation when using anonymous operations')
    end

    if operationName then
      return
    end

    -- Anonymous: any previously recorded named operation is a conflict.
    if next(context.operationNames) then
      error('Cannot have more than one operation when using anonymous operations')
    end
    context.hasAnonymousOperation = true
  end
  --- Errors when a selected field does not exist on its parent type.
  -- The top of `context.objects` is `false` for such a field (the validator
  -- appears to push `false` for fields missing from the parent definition —
  -- confirm in the traversal); the entry below it is the parent type.
  -- @param node field node
  -- @param context validation context (reads `objects`)
  function rules.fieldsDefinedOnType(node, context)
    local objects = context.objects
    if objects[#objects] ~= false then
      return
    end

    -- A NonNull/List wrapper may carry no `name`; walk `ofType` to the named
    -- type so the message never fails with "attempt to concatenate a nil value".
    local parent = objects[#objects - 1]
    while type(parent) == 'table' and parent.name == nil and parent.ofType do
      parent = parent.ofType
    end
    local parentName = type(parent) == 'table' and parent.name or nil

    error('Field "' .. node.name.value .. '" is not defined on type "' .. tostring(parentName) .. '"')
  end
  --- Errors when a field is passed an argument its definition does not declare.
  -- The field definition is looked up by name on the entry below the field in
  -- `context.objects`. If it cannot be resolved there (no `fields` table, or
  -- the field is not listed, e.g. introspection meta-fields), the check is
  -- skipped rather than crashing; undefined fields are reported by
  -- fieldsDefinedOnType.
  -- @param node field node (reads optional `arguments` array)
  -- @param context validation context (reads `objects`)
  function rules.argumentsDefinedOnType(node, context)
    if not node.arguments then
      return
    end

    local parent = context.objects[#context.objects - 1]
    local parentField = type(parent) == 'table' and parent.fields
      and parent.fields[node.name.value]
    if not parentField then
      return
    end

    -- A definition without an `arguments` table declares no arguments.
    local declared = parentField.arguments or {}
    -- ipairs: report the first offending argument in query order.
    for _, argument in ipairs(node.arguments) do
      local name = argument.name.value
      if not declared[name] then
        error('Non-existent argument "' .. name .. '"')
      end
    end
  end
  --- Errors when a leaf-typed field (scalar or enum) has a subselection.
  -- The field's type is the top of `context.objects`. NonNull/List wrappers
  -- are looked through, so `[String!]` is still treated as a scalar.
  -- A non-table top (`false` for unknown fields) is skipped; fieldsDefinedOnType
  -- reports those.
  -- @param node field node
  -- @param context validation context (reads `objects`)
  function rules.scalarFieldsAreLeaves(node, context)
    if not node.selectionSet then
      return
    end

    local kind = context.objects[#context.objects]
    if type(kind) ~= 'table' then
      return
    end
    while (kind.__type == 'NonNull' or kind.__type == 'List') and kind.ofType do
      kind = kind.ofType
    end

    if kind.__type == 'Scalar' then
      error('Scalar values cannot have subselections')
    elseif kind.__type == 'Enum' then
      -- Enums are leaf types too (GraphQL spec "Leaf Field Selections").
      error('Enum values cannot have subselections')
    end
  end
  --- Errors when a field of composite type (Object, Interface, Union) has no
  -- subselection. NonNull/List wrappers are looked through, so `[User!]!` still
  -- requires a selection set. A non-table top of `context.objects` (`false` for
  -- unknown fields) is skipped; fieldsDefinedOnType reports those.
  -- @param node field node
  -- @param context validation context (reads `objects`)
  function rules.compositeFieldsAreNotLeaves(node, context)
    if node.selectionSet then
      return
    end

    local kind = context.objects[#context.objects]
    if type(kind) ~= 'table' then
      return
    end
    while (kind.__type == 'NonNull' or kind.__type == 'List') and kind.ofType do
      kind = kind.ofType
    end

    local category = kind.__type
    if category == 'Object' or category == 'Interface' or category == 'Union' then
      error('Composite types must have subselections')
    end
  end
  --- Rejects a selection set in which two fields share a response key (alias
  -- or field name) but cannot be merged: different underlying field names,
  -- different return types, or different arguments.
  -- Fields reached through inline fragments and fragment spreads are compared
  -- with each other as well. Fields whose parents are two distinct Object types
  -- are not compared, since they cannot apply to the same result object.
  -- Limitations: only this level is compared (sub-selections of same-keyed
  -- fields are not merged), and fields not found on their parent type are
  -- skipped individually (fieldsDefinedOnType reports those).
  -- @param node selection set node
  -- @param context validation context (reads `objects`, `schema`, `fragmentMap`)
  function rules.unambiguousSelections(node, context)
    local selectionMap = {}
    local seen = {}

    -- Strip NonNull/List wrappers to reach the type that owns `fields`.
    local function unwrap(kind)
      while type(kind) == 'table' and (kind.__type == 'NonNull' or kind.__type == 'List') do
        kind = kind.ofType
      end
      return kind
    end

    -- Structural equality of two schema types, looking through NonNull/List
    -- wrappers in case equivalent wrappers were constructed separately.
    local function sameType(a, b)
      if a == b then
        return true
      end
      if type(a) ~= 'table' or type(b) ~= 'table' or a.__type ~= b.__type then
        return false
      end
      if a.__type == 'NonNull' or a.__type == 'List' then
        return sameType(a.ofType, b.ofType)
      end
      return false
    end

    -- Structural equality of two argument value AST nodes. Variables compare
    -- by name; input objects and list literals compare their contents.
    local function valuesEqual(a, b)
      if a.kind ~= b.kind then
        return false
      end
      if a.kind == 'variable' then
        return a.name.value == b.name.value
      elseif a.kind == 'inputObject' then
        if #a.values ~= #b.values then
          return false
        end
        local fieldsA = {}
        for _, field in ipairs(a.values) do
          fieldsA[field.name] = field.value
        end
        for _, field in ipairs(b.values) do
          local other = fieldsA[field.name]
          if other == nil or not valuesEqual(other, field.value) then
            return false
          end
        end
        return true
      elseif type(a.values) == 'table' and type(b.values) == 'table' then
        -- List literals (NOTE(review): assumed to keep items in `values` —
        -- confirm against the parser): compare item by item, in order.
        if #a.values ~= #b.values then
          return false
        end
        for i = 1, #a.values do
          if not valuesEqual(a.values[i], b.values[i]) then
            return false
          end
        end
        return true
      end
      return a.value == b.value
    end

    -- Returns an error message if two entries with the same response key
    -- cannot be merged, or nil if they can.
    local function findConflict(entryA, entryB)
      -- Parent types can't overlap if they're different objects.
      -- Interface and union types may overlap.
      if entryA.parent ~= entryB.parent
        and entryA.parent.__type == 'Object'
        and entryB.parent.__type == 'Object' then
        return nil
      end
      -- Aliases mapping two different fields to the same name.
      if entryA.field.name.value ~= entryB.field.name.value then
        return 'Type name mismatch'
      end
      -- Same field name but different return types.
      if entryA.definition and entryB.definition
        and not sameType(entryA.definition, entryB.definition) then
        return 'Return type mismatch'
      end
      -- Arguments must be identical (same names, structurally equal values).
      local argsA = entryA.field.arguments or {}
      local argsB = entryB.field.arguments or {}
      if #argsA ~= #argsB then
        return 'Argument mismatch'
      end
      local argMap = {}
      for i = 1, #argsA do
        argMap[argsA[i].name.value] = argsA[i].value
      end
      for i = 1, #argsB do
        local valueA = argMap[argsB[i].name.value]
        if not valueA or not valuesEqual(valueA, argsB[i].value) then
          return 'Argument mismatch'
        end
      end
      return nil
    end

    -- Compares an entry against every earlier entry with the same key, then records it.
    local function validateField(key, entry)
      local previous = selectionMap[key]
      if not previous then
        selectionMap[key] = { entry }
        return
      end
      for i = 1, #previous do
        local conflict = findConflict(previous[i], entry)
        if conflict then
          error(conflict)
        end
      end
      previous[#previous + 1] = entry
    end

    -- Walks a selection set, descending into inline fragments and (once each)
    -- named fragment spreads, registering every resolvable field.
    local function validateSelectionSet(selectionSet, parentType)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'field' then
          local fieldDefinition = type(parentType) == 'table' and parentType.fields
            and parentType.fields[selection.name.value]
          -- Skip only this field when it cannot be resolved; keep checking the rest.
          if fieldDefinition then
            local key = selection.alias and selection.alias.name.value or selection.name.value
            validateField(key, {
              parent = parentType,
              field = selection,
              definition = fieldDefinition.kind
            })
          end
        elseif selection.kind == 'inlineFragment' then
          local fragmentType = selection.typeCondition
            and context.schema:getType(selection.typeCondition.name.value)
            or parentType
          validateSelectionSet(selection.selectionSet, unwrap(fragmentType))
        elseif selection.kind == 'fragmentSpread' then
          local fragmentDefinition = context.fragmentMap[selection.name.value]
          if fragmentDefinition and not seen[fragmentDefinition] then
            seen[fragmentDefinition] = true
            if fragmentDefinition.typeCondition then
              local fragmentType = context.schema:getType(fragmentDefinition.typeCondition.name.value)
              validateSelectionSet(fragmentDefinition.selectionSet, fragmentType)
            end
          end
        end
      end
    end

    validateSelectionSet(node, unwrap(context.objects[#context.objects]))
  end
  --- Errors when a node's argument list names the same argument twice.
  -- Nodes without an `arguments` array are accepted.
  -- @param node node with an optional `arguments` array
  -- @param context unused
  function rules.uniqueArgumentNames(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end

    local taken = {}
    for _, arg in ipairs(supplied) do
      local argName = arg.name.value
      if taken[argName] then
        error('Encountered multiple arguments named "' .. argName .. '"')
      end
      taken[argName] = true
    end
  end
  --- Checks each supplied argument value against its declared type by passing
  -- it through `util.coerceValue` (which is expected to raise on a bad value).
  -- The field definition is looked up on the entry below the field in
  -- `context.objects`. Arguments that cannot be resolved (unknown field, no
  -- `arguments` table, or an undeclared argument) are skipped rather than
  -- handing nil to `coerceValue`; argumentsDefinedOnType reports undeclared ones.
  -- @param node field node (reads optional `arguments` array)
  -- @param context validation context (reads `objects`)
  function rules.argumentsOfCorrectType(node, context)
    if not node.arguments then
      return
    end

    local parent = context.objects[#context.objects - 1]
    local parentField = type(parent) == 'table' and parent.fields
      and parent.fields[node.name.value]
    if not parentField or not parentField.arguments then
      return
    end

    for _, argument in ipairs(node.arguments) do
      local argumentType = parentField.arguments[argument.name.value]
      if argumentType then
        util.coerceValue(argument.value, argumentType)
      end
    end
  end
  --- Errors when a NonNull argument declared on the field's definition is not
  -- supplied in the query.
  -- The field definition is looked up on the entry below the field in
  -- `context.objects`; when that entry is a List, the entry beneath it is used
  -- instead (NOTE(review): mirrors how the validator stacks list types —
  -- confirm in the traversal). If no definition is found, the check is
  -- skipped; undefined fields are reported by fieldsDefinedOnType.
  -- @param node field node (reads optional `arguments` array)
  -- @param context validation context (reads `objects`)
  function rules.requiredArgumentsPresent(node, context)
    local objects = context.objects
    local parent = objects[#objects - 1]
    if type(parent) == 'table' and parent.__type == 'List' then
      parent = objects[#objects - 2]
    end

    local parentField = type(parent) == 'table' and parent.fields
      and parent.fields[node.name.value]
    if not parentField then
      return
    end

    -- Set of argument names supplied in the query, for constant-time lookup.
    local supplied = {}
    for _, argument in ipairs(node.arguments or {}) do
      supplied[argument.name.value] = true
    end

    for name, argumentType in pairs(parentField.arguments or {}) do
      if argumentType.__type == 'NonNull' and not supplied[name] then
        error('Required argument "' .. name .. '" was not supplied.')
      end
    end
  end
  --- Errors when a document defines two fragments with the same name.
  -- Non-fragment definitions are ignored.
  -- @param node document node (reads `definitions`)
  -- @param context unused
  function rules.uniqueFragmentNames(node, context)
    local defined = {}

    local function register(fragmentName)
      if defined[fragmentName] then
        error('Encountered multiple fragments named "' .. fragmentName .. '"')
      end
      defined[fragmentName] = true
    end

    for _, def in ipairs(node.definitions) do
      if def.kind == 'fragmentDefinition' then
        register(def.name.value)
      end
    end
  end
  --- Checks a fragment's type condition: the named type must exist in the
  -- schema and be an Object, Interface, or Union. Fragments without a type
  -- condition are accepted.
  -- @param node fragment node (reads optional `typeCondition`)
  -- @param context validation context (reads `schema`)
  function rules.fragmentHasValidType(node, context)
    local condition = node.typeCondition
    if not condition then
      return
    end

    local typeName = condition.name.value
    local target = context.schema:getType(typeName)
    if not target then
      error('Fragment refers to non-existent type "' .. typeName .. '"')
    end

    local category = target.__type
    local isComposite = category == 'Object' or category == 'Interface' or category == 'Union'
    if not isComposite then
      error('Fragment type must be an Object, Interface, or Union, got ' .. category)
    end
  end
  --- Errors on the first fragment definition whose name is not recorded in
  -- `context.usedFragments`.
  -- @param node document node (reads `definitions`)
  -- @param context validation context (reads `usedFragments`)
  function rules.noUnusedFragments(node, context)
    for _, def in ipairs(node.definitions) do
      if def.kind == 'fragmentDefinition' then
        local fragmentName = def.name.value
        local wasUsed = context.usedFragments[fragmentName]
        if not wasUsed then
          error('Fragment "' .. fragmentName .. '" was not used.')
        end
      end
    end
  end
  --- Errors when a fragment spread names a fragment absent from
  -- `context.fragmentMap`.
  -- @param node fragment spread node
  -- @param context validation context (reads `fragmentMap`)
  function rules.fragmentSpreadTargetDefined(node, context)
    local target = node.name.value
    local definition = context.fragmentMap[target]
    if definition then
      return
    end
    error('Fragment spread refers to non-existent fragment "' .. target .. '"')
  end
  --- Errors when a fragment definition can reach itself (directly or through
  -- other fragments) via fragment spreads.
  -- Performs a depth-first walk over spreads found anywhere in the selection
  -- tree: inside fields, inline fragments, and referenced fragments. A cycle is
  -- a spread back to a fragment that is on the current path. Fragments reached
  -- by more than one route (diamonds, repeated spreads) are not cycles and are
  -- explored once. Spreads to unknown fragments are ignored here
  -- (fragmentSpreadTargetDefined reports them).
  -- @param node fragment definition node
  -- @param context validation context (reads `fragmentMap`)
  function rules.fragmentDefinitionHasNoCycles(node, context)
    local onPath = {}    -- fragments on the current DFS path
    local finished = {}  -- fragments fully explored without finding a cycle

    local detectCycles

    -- Enters a named fragment, raising if it is already on the current path.
    local function visitFragment(name)
      if onPath[name] then
        error('Fragment definition has cycles')
      end
      if finished[name] then
        return
      end
      local fragmentDefinition = context.fragmentMap[name]
      if not fragmentDefinition then
        return
      end
      onPath[name] = true
      detectCycles(fragmentDefinition.selectionSet)
      onPath[name] = nil
      finished[name] = true
    end

    detectCycles = function(selectionSet)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'fragmentSpread' then
          visitFragment(selection.name.value)
        elseif selection.selectionSet then
          -- Fields and inline fragments may contain spreads further down.
          detectCycles(selection.selectionSet)
        end
      end
    end

    onPath[node.name.value] = true
    detectCycles(node.selectionSet)
  end
  --- Errors when a fragment (inline or named) can never apply at its spread
  -- site, i.e. the possible object types of the enclosing type and of the
  -- fragment's type condition do not intersect.
  -- The enclosing type is the entry below the current one in `context.objects`;
  -- NonNull/List wrappers are looked through on both sides. Missing types or
  -- an undefined named fragment make the rule return early so other rules can
  -- report them.
  -- @param node inlineFragment or fragmentSpread node
  -- @param context validation context (reads `objects`, `schema`, `fragmentMap`)
  function rules.fragmentSpreadIsPossible(node, context)
    local schema = context.schema

    local function unwrap(kind)
      while type(kind) == 'table' and (kind.__type == 'NonNull' or kind.__type == 'List') do
        kind = kind.ofType
      end
      return kind
    end

    local parentType = unwrap(context.objects[#context.objects - 1])
    local fragmentType
    if node.kind == 'inlineFragment' then
      fragmentType = node.typeCondition and schema:getType(node.typeCondition.name.value) or parentType
    else
      local fragment = context.fragmentMap[node.name.value]
      if not fragment or not fragment.typeCondition then
        return
      end
      fragmentType = schema:getType(fragment.typeCondition.name.value)
    end
    fragmentType = unwrap(fragmentType)

    -- Some types are not present in the schema. Let other rules handle this.
    if type(parentType) ~= 'table' or type(fragmentType) ~= 'table' then
      return
    end

    -- Set of concrete object types a value of `kind` may be; empty for
    -- non-composite kinds so the intersection test below fails cleanly.
    local function getTypes(kind)
      if kind.__type == 'Object' then
        return { [kind] = kind }
      elseif kind.__type == 'Interface' then
        return schema:getImplementors(kind.name) or {}
      elseif kind.__type == 'Union' then
        local possible = {}
        for i = 1, #kind.types do
          possible[kind.types[i]] = kind.types[i]
        end
        return possible
      end
      return {}
    end

    local parentTypes = getTypes(parentType)
    local fragmentTypes = getTypes(fragmentType)
    local valid = util.find(parentTypes, function(kind)
      return fragmentTypes[kind]
    end)
    if not valid then
      error('Fragment type condition is not possible for given type')
    end
  end
  --- Errors when an input object literal in `node.value` names the same field
  -- twice. Nested input objects and the items of list literals are checked
  -- recursively.
  -- @param node node carrying a `value` AST (e.g. an argument)
  -- @param context unused
  function rules.uniqueInputObjectFields(node, context)
    local function validateValue(value)
      if type(value) ~= 'table' then
        return
      end
      if value.kind == 'listType' or value.kind == 'nonNullType' then
        -- Type AST wrappers: look through to the wrapped node.
        return validateValue(value.type)
      elseif value.kind == 'list' then
        -- List literal. NOTE(review): assumes the parser uses kind 'list' with
        -- items in `values` — confirm.
        for _, item in ipairs(value.values or {}) do
          validateValue(item)
        end
      elseif value.kind == 'inputObject' then
        local fieldMap = {}
        for _, field in ipairs(value.values) do
          if fieldMap[field.name] then
            error('Multiple input object fields named "' .. field.name .. '"')
          end
          fieldMap[field.name] = true
          validateValue(field.value)
        end
      end
    end

    validateValue(node.value)
  end
  --- Errors when a node uses a directive the schema does not define.
  -- Nodes without `directives` are accepted.
  -- @param node node with an optional `directives` collection
  -- @param context validation context (reads `schema`)
  function rules.directivesAreDefined(node, context)
    local used = node.directives
    if not used then
      return
    end

    local schema = context.schema
    for _, directive in pairs(used) do
      local directiveName = directive.name.value
      local known = schema:getDirective(directiveName)
      if not known then
        error('Unknown directive "' .. directive.name.value .. '"')
      end
    end
  end
  --- Checks every variable definition's declared type: after stripping
  -- list/non-null wrappers, the named type must exist in the schema and be a
  -- Scalar, Enum, or InputObject.
  -- @param node operation node (reads optional `variableDefinitions`)
  -- @param context validation context (reads `schema`)
  function rules.variablesHaveCorrectType(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end

    local inputCategories = { Scalar = true, Enum = true, InputObject = true }

    local function checkTypeNode(typeNode)
      while typeNode.kind == 'listType' or typeNode.kind == 'nonNullType' do
        typeNode = typeNode.type
      end
      if typeNode.kind ~= 'namedType' then
        return
      end

      local schemaType = context.schema:getType(typeNode.name.value)
      if not schemaType then
        error('Variable specifies unknown type "' .. tostring(typeNode.name.value) .. '"')
      end
      if not inputCategories[schemaType.__type] then
        error('Variable types must be scalars, enums, or input objects, got "' .. schemaType.__type .. '"')
      end
    end

    for _, definition in ipairs(definitions) do
      checkTypeNode(definition.type)
    end
  end
  --- Validates variable default values.
  -- A non-null variable with a default value is rejected. Otherwise the default
  -- is passed through `util.coerceValue` against the variable's full declared
  -- type, including list/non-null wrappers; previously list-typed variables
  -- crashed on `type.name`. Variables whose named type is unknown are skipped
  -- (variablesHaveCorrectType reports those).
  -- @param node operation node (reads optional `variableDefinitions`)
  -- @param context validation context (reads `schema`)
  function rules.variableDefaultValuesHaveCorrectType(node, context)
    if not node.variableDefinitions then
      return
    end

    -- Builds the schema type for a type AST node; nil if a named type is unknown.
    local function typeFromAST(typeNode)
      if typeNode.kind == 'nonNullType' then
        local inner = typeFromAST(typeNode.type)
        return inner and types.nonNull(inner)
      elseif typeNode.kind == 'listType' then
        local inner = typeFromAST(typeNode.type)
        return inner and types.list(inner)
      end
      return context.schema:getType(typeNode.name.value)
    end

    for _, definition in ipairs(node.variableDefinitions) do
      if definition.defaultValue then
        if definition.type.kind == 'nonNullType' then
          error('Non-null variables can not have default values')
        end
        local schemaType = typeFromAST(definition.type)
        if schemaType then
          util.coerceValue(definition.defaultValue, schemaType)
        end
      end
    end
  end
  --- Errors on the first declared variable that is not present in
  -- `context.variableReferences`.
  -- @param node operation node (reads optional `variableDefinitions`)
  -- @param context validation context (reads `variableReferences`)
  function rules.variablesAreUsed(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end

    for _, definition in ipairs(definitions) do
      local declaredName = definition.variable.name.value
      local referenced = context.variableReferences[declaredName]
      if not referenced then
        error('Unused variable "' .. declaredName .. '"')
      end
    end
  end
  --- Errors when a variable referenced in the operation (a key of
  -- `context.variableReferences`) has no matching definition.
  -- Does nothing when `context.variableReferences` is unset.
  -- @param node operation node (reads optional `variableDefinitions`)
  -- @param context validation context (reads `variableReferences`)
  function rules.variablesAreDefined(node, context)
    local references = context.variableReferences
    if not references then
      return
    end

    local declared = {}
    for _, definition in ipairs(node.variableDefinitions or {}) do
      declared[definition.variable.name.value] = true
    end

    for referencedName in pairs(references) do
      if not declared[referencedName] then
        error('Unknown variable "' .. referencedName .. '"')
      end
    end
  end
  --- Checks that each variable passed as a field argument has a declared type
  -- usable where it appears: the variable's type (made non-null if it has a
  -- default value) must be a subtype of the argument's type.
  -- For a field node, its own arguments are checked. For a fragment spread,
  -- arguments of fields collected from the referenced fragment (through inline
  -- fragments and nested spreads, not into field sub-selections) are checked.
  -- Field definitions are looked up by name on the entry below the current one
  -- in `context.objects`. Unresolvable pieces are skipped rather than crashing;
  -- other rules report them (unknown field/argument/variable/type).
  -- Does nothing outside an operation (`context.currentOperation` unset).
  -- Raises 'Variable type mismatch' on an incompatible usage.
  -- @param node field or fragmentSpread node
  -- @param context validation context
  function rules.variableUsageAllowed(node, context)
    local operation = context.currentOperation
    if not operation then
      return
    end
    local schema = context.schema

    -- Converts a variable's type AST into a schema type; nil if a named type is unknown.
    local function typeFromAST(variable)
      local innerType
      if variable.kind == 'listType' then
        innerType = typeFromAST(variable.type)
        return innerType and types.list(innerType)
      elseif variable.kind == 'nonNullType' then
        innerType = typeFromAST(variable.type)
        return innerType and types.nonNull(innerType)
      else
        assert(variable.kind == 'namedType', 'Variable must be a named type')
        return schema:getType(variable.name.value)
      end
    end

    -- True when a value of subType may be used where superType is expected.
    local function isTypeSubTypeOf(subType, superType)
      if subType == superType then return true end
      if superType.__type == 'NonNull' then
        if subType.__type == 'NonNull' then
          return isTypeSubTypeOf(subType.ofType, superType.ofType)
        end
        return false
      elseif subType.__type == 'NonNull' then
        return isTypeSubTypeOf(subType.ofType, superType)
      end
      if superType.__type == 'List' then
        if subType.__type == 'List' then
          return isTypeSubTypeOf(subType.ofType, superType.ofType)
        end
        return false
      elseif subType.__type == 'List' then
        return false
      end
      if subType.__type ~= 'Object' then return false end
      if superType.__type == 'Interface' then
        local implementors = schema:getImplementors(superType.name)
        return implementors and implementors[schema:getType(subType.name)]
      elseif superType.__type == 'Union' then
        local possibleTypes = superType.types
        for i = 1, #possibleTypes do
          if possibleTypes[i] == subType then
            return true
          end
        end
        return false
      end
      return false
    end

    -- Gathers arguments, keyed by field name, from fields reachable in a
    -- fragment through inline fragments and nested spreads. Each selection is
    -- visited at most once, which also stops recursion on fragment cycles.
    local function collectSpreadArguments(fragment)
      local collected = {}
      local seen = {}
      local function collect(referencedNode)
        local kind = referencedNode.kind
        if kind == 'selectionSet' then
          for _, selection in ipairs(referencedNode.selections) do
            if not seen[selection] then
              seen[selection] = true
              collect(selection)
            end
          end
        elseif kind == 'field' then
          if referencedNode.arguments then
            local fieldName = referencedNode.name.value
            local list = collected[fieldName] or {}
            collected[fieldName] = list
            for _, argument in ipairs(referencedNode.arguments) do
              list[#list + 1] = argument
            end
          end
        elseif kind == 'inlineFragment' then
          collect(referencedNode.selectionSet)
        elseif kind == 'fragmentSpread' then
          local target = context.fragmentMap[referencedNode.name.value]
          if target then
            collect(target.selectionSet)
          end
        end
      end
      collect(fragment.selectionSet)
      return collected
    end

    local variableMap = {}
    for _, definition in ipairs(operation.variableDefinitions or {}) do
      variableMap[definition.variable.name.value] = definition
    end

    local arguments
    if node.kind == 'field' then
      -- A nil `node.arguments` leaves this table empty.
      arguments = { [node.name.value] = node.arguments }
    elseif node.kind == 'fragmentSpread' then
      local fragment = context.fragmentMap[node.name.value]
      if fragment then
        arguments = collectSpreadArguments(fragment)
      end
    end
    if not arguments then return end

    local parent = context.objects[#context.objects - 1]
    local parentFields = type(parent) == 'table' and parent.fields or nil

    for fieldName, fieldArguments in pairs(arguments) do
      local parentField = parentFields and parentFields[fieldName]
      local argumentTypes = parentField and parentField.arguments
      if argumentTypes then
        for i = 1, #fieldArguments do
          local argument = fieldArguments[i]
          if argument.value.kind == 'variable' then
            local argumentType = argumentTypes[argument.name.value]
            local variableDefinition = variableMap[argument.value.name.value]
            if argumentType and variableDefinition then
              local variableType = typeFromAST(variableDefinition.type)
              if variableType then
                -- A default value makes the variable effectively non-null.
                if variableDefinition.defaultValue ~= nil and variableType.__type ~= 'NonNull' then
                  variableType = types.nonNull(variableType)
                end
                if not isTypeSubTypeOf(variableType, argumentType) then
                  error('Variable type mismatch')
                end
              end
            end
          end
        end
      end
    end
  end
  449. return rules