rules.lua 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. local path = (...):gsub('%.[^%.]+$', '')
  2. local types = require(path .. '.types')
  3. local util = require(path .. '.util')
  4. local rules = {}
  --- Reject a document that contains two operations with the same name.
  -- Anonymous operations are ignored; every named operation is recorded in
  -- `context.operationNames` so later operations can be compared against it.
  -- @param node operation definition AST node
  -- @param context validation context (reads/writes `operationNames`)
  function rules.uniqueOperationNames(node, context)
    local operationName = node.name and node.name.value
    if not operationName then
      return
    end
    local registry = context.operationNames
    if registry[operationName] then
      error('Multiple operations exist named "' .. operationName .. '"')
    end
    registry[operationName] = true
  end
  --- Enforce that an anonymous operation is the only operation in the document.
  -- Errors when an anonymous operation was already seen, or when this operation
  -- is anonymous while named operations have already been recorded.
  -- @param node operation definition AST node
  -- @param context validation context (reads `operationNames`, reads/writes `hasAnonymousOperation`)
  function rules.loneAnonymousOperation(node, context)
    local isAnonymous = not (node.name and node.name.value)
    local conflict = context.hasAnonymousOperation
      or (isAnonymous and next(context.operationNames))
    if conflict then
      error('Cannot have more than one operation when using anonymous operations')
    end
    if isAnonymous then
      context.hasAnonymousOperation = true
    end
  end
  --- Report a selected field that does not exist on its parent type.
  -- A `false` entry on top of `context.objects` marks the current field as
  -- unresolved (presumably pushed by the traversal when the parent type has no
  -- such field — confirm against the visitor). The entry below it is the parent.
  -- @param node field AST node
  -- @param context validation context (reads `objects`)
  function rules.fieldsDefinedOnType(node, context)
    local objects = context.objects
    if objects[#objects] ~= false then
      return
    end
    local parent = objects[#objects - 1]
    -- Strip every List/NonNull wrapper (both expose `ofType`) so the message
    -- names the underlying type; stripping a single List layer left `name`
    -- nil for NonNull/nested wrappers and the concatenation itself failed.
    while parent and parent.ofType do
      parent = parent.ofType
    end
    error('Field "' .. node.name.value .. '" is not defined on type "' .. tostring(parent and parent.name) .. '"')
  end
  --- Reject arguments that the parent field does not declare.
  -- @param node field AST node (its `arguments`, if any, are checked)
  -- @param context validation context passed to util.getParentField
  function rules.argumentsDefinedOnType(node, context)
    local suppliedArguments = node.arguments
    if not suppliedArguments then
      return
    end
    local parentField = util.getParentField(context, node.name.value, 1)
    for _, supplied in pairs(suppliedArguments) do
      local argumentName = supplied.name.value
      local declared = parentField.arguments[argumentName]
      if not declared then
        error('Non-existent argument "' .. argumentName .. '"')
      end
    end
  end
  --- Reject sub-selections on a field whose current type is a Scalar.
  -- @param node field AST node
  -- @param context validation context (top of `objects` is the field's type)
  function rules.scalarFieldsAreLeaves(node, context)
    local current = context.objects[#context.objects]
    local isScalar = current.__type == 'Scalar'
    if isScalar and node.selectionSet then
      error('Scalar values cannot have subselections')
    end
  end
  -- Type kinds that require an explicit sub-selection.
  local COMPOSITE_TYPES = { Object = true, Interface = true, Union = true }

  --- Require a sub-selection on fields whose current type is composite.
  -- @param node field AST node
  -- @param context validation context (top of `objects` is the field's type)
  function rules.compositeFieldsAreNotLeaves(node, context)
    local kind = context.objects[#context.objects].__type
    if COMPOSITE_TYPES[kind] and not node.selectionSet then
      error('Composite types must have subselections')
    end
  end
  --- Check that fields sharing a response key in one selection set can be merged.
  -- Walks the selection set `node`, its inline fragments and (once each) the
  -- named fragments spread into it, grouping field selections by response key
  -- (alias if present, otherwise field name). Each field is compared with all
  -- earlier fields under the same key; the first conflict raises its message.
  -- @param node selection set AST node
  -- @param context validation context (reads `objects`, `schema`, `fragmentMap`)
  function rules.unambiguousSelections(node, context)
    local selectionMap = {}
    local seen = {}

    -- Return a conflict message for two entries with the same key, or nil.
    local function findConflict(entryA, entryB)
      -- Fields on two different object parents never appear in the same
      -- response, so they cannot conflict. Interface and union parents may
      -- overlap. (The parents' `__type` must be read: entries have no
      -- `__type` of their own, which made the old check never fire.)
      if entryA.parent ~= entryB.parent
        and entryA.parent.__type == 'Object'
        and entryB.parent.__type == 'Object' then
        return
      end

      -- Error if there are aliases that map two different fields to the same name.
      if entryA.field.name.value ~= entryB.field.name.value then
        return 'Type name mismatch'
      end

      -- Error if there are fields with the same name that have different return types.
      if entryA.definition and entryB.definition and entryA.definition ~= entryB.definition then
        return 'Return type mismatch'
      end

      -- Error if arguments are not identical for two fields with the same name.
      local argsA = entryA.field.arguments or {}
      local argsB = entryB.field.arguments or {}
      if #argsA ~= #argsB then
        return 'Argument mismatch'
      end

      local argMap = {}
      for i = 1, #argsA do
        argMap[argsA[i].name.value] = argsA[i].value
      end

      for i = 1, #argsB do
        local valueA = argMap[argsB[i].name.value]
        local valueB = argsB[i].value
        if not valueA or valueA.kind ~= valueB.kind or valueA.value ~= valueB.value then
          return 'Argument mismatch'
        end
      end
    end

    -- Compare `entry` against every earlier entry under `key`, then record it.
    local function validateField(key, entry)
      local previous = selectionMap[key]
      if not previous then
        selectionMap[key] = { entry }
        return
      end
      for i = 1, #previous do
        local conflict = findConflict(previous[i], entry)
        if conflict then
          error(conflict)
        end
      end
      table.insert(previous, entry)
    end

    -- Recursively make sure that there are no ambiguous selections with the same name.
    local function validateSelectionSet(selectionSet, parentType)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'field' then
          local fieldName = selection.name.value
          local fieldDefinition = parentType and parentType.fields and parentType.fields[fieldName]
          -- Skip only this field when it cannot be resolved (other rules
          -- report it); returning here used to skip all later siblings too.
          if fieldDefinition then
            local key = selection.alias and selection.alias.name.value or fieldName
            validateField(key, {
              parent = parentType,
              field = selection,
              definition = fieldDefinition.kind
            })
          end
        elseif selection.kind == 'inlineFragment' then
          local fragmentType = selection.typeCondition
            and context.schema:getType(selection.typeCondition.name.value)
            or parentType
          validateSelectionSet(selection.selectionSet, fragmentType)
        elseif selection.kind == 'fragmentSpread' then
          local fragmentDefinition = context.fragmentMap[selection.name.value]
          if fragmentDefinition and not seen[fragmentDefinition] then
            seen[fragmentDefinition] = true
            if fragmentDefinition.typeCondition then
              local fragmentType = context.schema:getType(fragmentDefinition.typeCondition.name.value)
              validateSelectionSet(fragmentDefinition.selectionSet, fragmentType)
            end
          end
        end
      end
    end

    validateSelectionSet(node, context.objects[#context.objects])
  end
  --- Reject a field or directive that passes the same argument name twice.
  -- @param node AST node with an optional `arguments` list
  -- @param context validation context (unused)
  function rules.uniqueArgumentNames(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end
    local seenNames = {}
    for _, arg in ipairs(supplied) do
      local argName = arg.name.value
      if seenNames[argName] then
        error('Encountered multiple arguments named "' .. argName .. '"')
      end
      seenNames[argName] = true
    end
  end
  --- Coerce each supplied argument value against its declared type.
  -- util.coerceValue is relied on to raise when a value does not fit.
  -- @param node field AST node with an optional `arguments` list
  -- @param context validation context passed to util.getParentField
  function rules.argumentsOfCorrectType(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end
    local parentField = util.getParentField(context, node.name.value, 1)
    for _, argument in pairs(supplied) do
      local expectedType = parentField.arguments[argument.name.value]
      util.coerceValue(argument.value, expectedType)
    end
  end
  --- Require every NonNull argument declared on the parent field to be supplied.
  -- @param node field AST node with an optional `arguments` list
  -- @param context validation context passed to util.getParentField
  function rules.requiredArgumentsPresent(node, context)
    local supplied = node.arguments or {}
    local parentField = util.getParentField(context, node.name.value, 1)
    for argumentName, argumentType in pairs(parentField.arguments) do
      if argumentType.__type == 'NonNull' then
        local isPresent = util.find(supplied, function(candidate)
          return candidate.name.value == argumentName
        end)
        if not isPresent then
          error('Required argument "' .. argumentName .. '" was not supplied.')
        end
      end
    end
  end
  --- Reject a document that defines two fragments with the same name.
  -- @param node document AST node (its `definitions` are scanned)
  -- @param context validation context (unused)
  function rules.uniqueFragmentNames(node, context)
    local definedFragments = {}
    for _, def in ipairs(node.definitions) do
      local isFragment = def.kind == 'fragmentDefinition'
      if isFragment then
        local fragmentName = def.name.value
        if definedFragments[fragmentName] then
          error('Encountered multiple fragments named "' .. fragmentName .. '"')
        end
        definedFragments[fragmentName] = true
      end
    end
  end
  -- Type kinds a fragment's type condition may name.
  local FRAGMENT_TARGET_KINDS = { Object = true, Interface = true, Union = true }

  --- Check that a fragment's type condition names an existing composite type.
  -- Fragments without a type condition are accepted.
  -- @param node fragment definition or inline fragment AST node
  -- @param context validation context (reads `schema`)
  function rules.fragmentHasValidType(node, context)
    local condition = node.typeCondition
    if condition then
      local typeName = condition.name.value
      local target = context.schema:getType(typeName)
      if not target then
        error('Fragment refers to non-existent type "' .. typeName .. '"')
      end
      if not FRAGMENT_TARGET_KINDS[target.__type] then
        error('Fragment type must be an Object, Interface, or Union, got ' .. target.__type)
      end
    end
  end
  --- Reject fragment definitions that are never spread.
  -- @param node document AST node (its `definitions` are scanned)
  -- @param context validation context (reads `usedFragments`, keyed by fragment name)
  function rules.noUnusedFragments(node, context)
    local used = context.usedFragments
    for _, def in ipairs(node.definitions) do
      if def.kind == 'fragmentDefinition' and not used[def.name.value] then
        error('Fragment "' .. def.name.value .. '" was not used.')
      end
    end
  end
  --- Reject a fragment spread whose target fragment is not defined.
  -- @param node fragment spread AST node
  -- @param context validation context (reads `fragmentMap`)
  function rules.fragmentSpreadTargetDefined(node, context)
    local fragmentName = node.name.value
    local target = context.fragmentMap[fragmentName]
    if not target then
      error('Fragment spread refers to non-existent fragment "' .. fragmentName .. '"')
    end
  end
  --- Reject a fragment definition whose spreads eventually lead back to a
  -- fragment already on the current spread path (a cycle).
  -- Uses a depth-first search over spreads. A fragment on the current path
  -- (`onPath`) that is reached again is a cycle. A fragment already explored
  -- completely (`finished`) is not revisited, so spreading the same fragment
  -- twice (e.g. in two branches) is no longer mistaken for a cycle.
  -- Spreads nested in inline fragments and in field sub-selections are followed.
  -- @param node fragment definition AST node
  -- @param context validation context (reads `fragmentMap`)
  function rules.fragmentDefinitionHasNoCycles(node, context)
    local onPath = { [node.name.value] = true }
    local finished = {}
    local detectCycles

    -- Follow a spread of the fragment named `fragmentName`.
    local function visitSpread(fragmentName)
      if onPath[fragmentName] then
        error('Fragment definition has cycles')
      end
      if finished[fragmentName] then
        return
      end
      local fragmentDefinition = context.fragmentMap[fragmentName]
      if fragmentDefinition and fragmentDefinition.typeCondition then
        onPath[fragmentName] = true
        detectCycles(fragmentDefinition.selectionSet)
        onPath[fragmentName] = nil
      end
      finished[fragmentName] = true
    end

    detectCycles = function(selectionSet)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'fragmentSpread' then
          visitSpread(selection.name.value)
        elseif selection.selectionSet then
          -- Inline fragments and fields with sub-selections may contain spreads.
          detectCycles(selection.selectionSet)
        end
      end
    end

    detectCycles(node.selectionSet)
  end
  --- Check that a fragment can ever apply to the enclosing type.
  -- The possible concrete types of the enclosing (parent) type and of the
  -- fragment's type are computed; at least one must be shared.
  -- Missing fragments or types are left for other rules to report.
  -- @param node inline fragment or fragment spread AST node
  -- @param context validation context (reads `objects`, `schema`, `fragmentMap`)
  function rules.fragmentSpreadIsPossible(node, context)
    local isInline = node.kind == 'inlineFragment'
    local fragment = isInline and node or context.fragmentMap[node.name.value]
    -- Undefined spread target: fragmentSpreadTargetDefined reports this.
    if not fragment then return end

    local parentType = context.objects[#context.objects - 1]
    -- Strip List/NonNull wrappers. (This previously indexed the undefined
    -- global `parent`, so the rule always raised a Lua error.)
    while parentType and parentType.ofType do
      parentType = parentType.ofType
    end

    local fragmentType
    if isInline then
      fragmentType = node.typeCondition and context.schema:getType(node.typeCondition.name.value) or parentType
    else
      fragmentType = fragment.typeCondition and context.schema:getType(fragment.typeCondition.name.value)
    end

    -- Some types are not present in the schema. Let other rules handle this.
    if not parentType or not fragmentType then return end

    -- Map each possible concrete (Object) type of `kind` to itself.
    -- Non-composite kinds have no possible types.
    local function getTypes(kind)
      if kind.__type == 'Object' then
        return { [kind] = kind }
      elseif kind.__type == 'Interface' then
        return context.schema:getImplementors(kind.name) or {}
      elseif kind.__type == 'Union' then
        local members = {}
        for i = 1, #kind.types do
          members[kind.types[i]] = kind.types[i]
        end
        return members
      end
      return {}
    end

    local parentTypes = getTypes(parentType)
    local fragmentTypes = getTypes(fragmentType)
    local valid = util.find(parentTypes, function(kind)
      return fragmentTypes[kind]
    end)
    if not valid then
      error('Fragment type condition is not possible for given type')
    end
  end
  --- Reject input object literals that repeat a field name, at any nesting depth.
  -- Nested input objects are reached through field values and through the
  -- items of list values.
  -- @param node AST node carrying a `value` (e.g. an argument)
  -- @param context validation context (unused)
  function rules.uniqueInputObjectFields(node, context)
    local function validateValue(value)
      if value.kind == 'listType' or value.kind == 'nonNullType' then
        -- Defensive: type wrapper nodes, kept from the original implementation.
        return validateValue(value.type)
      elseif value.kind == 'list' then
        -- List literal: each item may itself contain input objects.
        -- NOTE(review): assumes list value nodes keep items in `values`; verify against the parser.
        for _, item in ipairs(value.values or {}) do
          validateValue(item)
        end
      elseif value.kind == 'inputObject' then
        local fieldMap = {}
        for _, field in ipairs(value.values) do
          if fieldMap[field.name] then
            error('Multiple input object fields named "' .. field.name .. '"')
          end
          fieldMap[field.name] = true
          validateValue(field.value)
        end
      end
    end

    validateValue(node.value)
  end
  --- Reject directives that the schema does not define.
  -- @param node AST node with an optional `directives` list
  -- @param context validation context (reads `schema`)
  function rules.directivesAreDefined(node, context)
    local directives = node.directives
    if directives then
      for _, directive in pairs(directives) do
        local directiveName = directive.name.value
        local known = context.schema:getDirective(directiveName)
        if not known then
          error('Unknown directive "' .. directiveName .. '"')
        end
      end
    end
  end
  -- Schema type kinds a variable may be declared as.
  local INPUT_TYPE_KINDS = { Scalar = true, Enum = true, InputObject = true }

  --- Check that every declared variable type exists and is an input type.
  -- List and NonNull wrappers are unwrapped down to the named type.
  -- @param node operation definition AST node
  -- @param context validation context (reads `schema`)
  function rules.variablesHaveCorrectType(node, context)
    local function validateType(typeNode)
      local kind = typeNode.kind
      if kind == 'listType' or kind == 'nonNullType' then
        validateType(typeNode.type)
        return
      end
      if kind ~= 'namedType' then
        return
      end
      local schemaType = context.schema:getType(typeNode.name.value)
      if not schemaType then
        error('Variable specifies unknown type "' .. tostring(typeNode.name.value) .. '"')
      elseif not INPUT_TYPE_KINDS[schemaType.__type] then
        error('Variable types must be scalars, enums, or input objects, got "' .. schemaType.__type .. '"')
      end
    end

    local definitions = node.variableDefinitions
    if not definitions then
      return
    end
    for _, definition in ipairs(definitions) do
      validateType(definition.type)
    end
  end
  --- Check variable default values against the variable's declared type.
  -- A top-level NonNull variable may not have a default. Any other default is
  -- coerced with util.coerceValue (relied on to raise on mismatch) against the
  -- full declared type, including List/NonNull wrappers. Reading
  -- `type.name` directly crashed for list-typed variables such as `[Int]`.
  -- @param node operation definition AST node
  -- @param context validation context (reads `schema`)
  function rules.variableDefaultValuesHaveCorrectType(node, context)
    if not node.variableDefinitions then return end

    -- Resolve a variable's type AST into a schema type; nil if the named type
    -- is unknown (variablesHaveCorrectType reports that case).
    local function typeFromAST(typeNode)
      if typeNode.kind == 'listType' then
        local innerType = typeFromAST(typeNode.type)
        return innerType and types.list(innerType)
      elseif typeNode.kind == 'nonNullType' then
        local innerType = typeFromAST(typeNode.type)
        return innerType and types.nonNull(innerType)
      end
      return context.schema:getType(typeNode.name.value)
    end

    for _, definition in ipairs(node.variableDefinitions) do
      if definition.type.kind == 'nonNullType' and definition.defaultValue then
        error('Non-null variables can not have default values')
      elseif definition.defaultValue then
        util.coerceValue(definition.defaultValue, typeFromAST(definition.type))
      end
    end
  end
  --- Reject operations that declare a variable they never reference.
  -- @param node operation definition AST node
  -- @param context validation context (reads `variableReferences`, keyed by variable name)
  function rules.variablesAreUsed(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end
    for _, definition in ipairs(definitions) do
      local name = definition.variable.name.value
      if not context.variableReferences[name] then
        error('Unused variable "' .. name .. '"')
      end
    end
  end
  --- Reject references to variables the operation does not declare.
  -- @param node operation definition AST node
  -- @param context validation context (reads `variableReferences`, keyed by variable name)
  function rules.variablesAreDefined(node, context)
    local references = context.variableReferences
    if not references then
      return
    end
    local declared = {}
    for _, definition in ipairs(node.variableDefinitions or {}) do
      declared[definition.variable.name.value] = true
    end
    for variableName in pairs(references) do
      if not declared[variableName] then
        error('Unknown variable "' .. variableName .. '"')
      end
    end
  end
  --- Check that variables passed as arguments have a compatible type.
  -- Runs only while `context.currentOperation` is set. For a field node, its
  -- own arguments are checked. For a fragment spread, arguments are gathered
  -- from fields in the spread fragment's selection set, including nested
  -- inline fragments and spreads.
  -- A variable with a default value is treated as NonNull.
  -- Undefined variables, unknown arguments and unknown variable types are
  -- skipped here; they are left to the rules that report them.
  -- NOTE(review): sub-selections of collected fields are not followed, and every
  -- collected field is resolved via util.getParentField against the current
  -- context — confirm this matches the visitor's intent.
  -- @param node field or fragmentSpread AST node
  -- @param context validation context
  function rules.variableUsageAllowed(node, context)
    if not context.currentOperation then return end

    local variableMap = {}
    for _, definition in ipairs(context.currentOperation.variableDefinitions or {}) do
      variableMap[definition.variable.name.value] = definition
    end

    -- Convert a variable's type AST into a schema type (nil if unknown).
    local function typeFromAST(variable)
      if variable.kind == 'listType' then
        local innerType = typeFromAST(variable.type)
        return innerType and types.list(innerType)
      elseif variable.kind == 'nonNullType' then
        local innerType = typeFromAST(variable.type)
        return innerType and types.nonNull(innerType)
      else
        assert(variable.kind == 'namedType', 'Variable must be a named type')
        return context.schema:getType(variable.name.value)
      end
    end

    -- Whether a value of `subType` may be used where `superType` is expected.
    local function isTypeSubTypeOf(subType, superType)
      if subType == superType then return true end

      if superType.__type == 'NonNull' then
        if subType.__type == 'NonNull' then
          return isTypeSubTypeOf(subType.ofType, superType.ofType)
        end
        return false
      elseif subType.__type == 'NonNull' then
        return isTypeSubTypeOf(subType.ofType, superType)
      end

      if superType.__type == 'List' then
        if subType.__type == 'List' then
          return isTypeSubTypeOf(subType.ofType, superType.ofType)
        end
        return false
      elseif subType.__type == 'List' then
        return false
      end

      if subType.__type ~= 'Object' then return false end

      if superType.__type == 'Interface' then
        local implementors = context.schema:getImplementors(superType.name)
        return implementors and implementors[context.schema:getType(subType.name)]
      elseif superType.__type == 'Union' then
        local members = superType.types
        for i = 1, #members do
          if members[i] == subType then
            return true
          end
        end
        return false
      end

      return false
    end

    -- Map of field name -> list of argument nodes to check.
    local arguments
    if node.kind == 'field' then
      arguments = { [node.name.value] = node.arguments }
    elseif node.kind == 'fragmentSpread' then
      local seen = {}
      local function collectArguments(referencedNode)
        if referencedNode.kind == 'selectionSet' then
          for _, selection in ipairs(referencedNode.selections) do
            if not seen[selection] then
              seen[selection] = true
              collectArguments(selection)
            end
          end
        elseif referencedNode.kind == 'field' and referencedNode.arguments then
          local fieldName = referencedNode.name.value
          arguments[fieldName] = arguments[fieldName] or {}
          for _, argument in ipairs(referencedNode.arguments) do
            table.insert(arguments[fieldName], argument)
          end
        elseif referencedNode.kind == 'inlineFragment' then
          return collectArguments(referencedNode.selectionSet)
        elseif referencedNode.kind == 'fragmentSpread' then
          local fragment = context.fragmentMap[referencedNode.name.value]
          return fragment and collectArguments(fragment.selectionSet)
        end
      end

      local fragment = context.fragmentMap[node.name.value]
      if fragment then
        arguments = {}
        collectArguments(fragment.selectionSet)
      end
    end

    if not arguments then return end

    for field in pairs(arguments) do
      local parentField = util.getParentField(context, field, 1)
      local declaredArguments = parentField and parentField.arguments
      for i = 1, #arguments[field] do
        local argument = arguments[field][i]
        if argument.value.kind == 'variable' and declaredArguments then
          local argumentType = declaredArguments[argument.name.value]
          local variableDefinition = variableMap[argument.value.name.value]
          local variableType = variableDefinition and typeFromAST(variableDefinition.type)
          if argumentType and variableType then
            if variableDefinition.defaultValue ~= nil and variableType.__type ~= 'NonNull' then
              variableType = types.nonNull(variableType)
            end
            if not isTypeSubTypeOf(variableType, argumentType) then
              error('Variable type mismatch')
            end
          end
        end
      end
    end
  end
  return rules