rules.lua 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. local path = (...):gsub('%.[^%.]+$', '')
  2. local types = require(path .. '.types')
  3. local util = require(path .. '.util')
  4. local schema = require(path .. '.schema')
  5. local introspection = require(path .. '.introspection')
  --- Resolve the definition of field `name` on an enclosing type.
  -- Introspection fields found in `introspection.fieldMap` are returned
  -- directly. Otherwise the entry `count` positions below the top of the
  -- `context.objects` stack (default 1) is taken, any `ofType` wrappers are
  -- stripped, and its `fields[name]` entry is returned (nil if absent).
  -- @param context validation context exposing the `objects` stack
  -- @param name field name to look up
  -- @param[opt=1] count offset below the top of `context.objects`
  local function getParentField(context, name, count)
    local introspectionField = introspection.fieldMap[name]
    if introspectionField then
      return introspectionField
    end

    local stack = context.objects
    local owner = stack[#stack - (count or 1)]

    -- Peel NonNull/List wrappers down to the named type.
    local inner = owner.ofType
    while inner do
      owner = inner
      inner = owner.ofType
    end

    return owner.fields[name]
  end

  -- Table of validation rules; each rule takes (node, context) and raises an
  -- error when the node is invalid.
  local rules = {}
  --- Reject a document that declares two operations with the same name.
  -- Anonymous operations are ignored. Every name seen is recorded in
  -- `context.operationNames`, which loneAnonymousOperation also reads.
  function rules.uniqueOperationNames(node, context)
    local operationName = node.name and node.name.value
    if not operationName then
      return
    end

    local known = context.operationNames
    if known[operationName] then
      error('Multiple operations exist named "' .. operationName .. '"')
    end
    known[operationName] = true
  end
  --- Allow an anonymous operation only when it is the sole operation.
  -- Fails if an anonymous operation was already seen, or if this operation is
  -- anonymous while named operations were already recorded.
  function rules.loneAnonymousOperation(node, context)
    local isAnonymous = not (node.name and node.name.value)

    local conflict = context.hasAnonymousOperation
      or (isAnonymous and next(context.operationNames) ~= nil)
    if conflict then
      error('Cannot have more than one operation when using anonymous operations')
    end

    if isAnonymous then
      context.hasAnonymousOperation = true
    end
  end
  --- Report a field that its parent type does not define.
  -- A `false` entry on top of `context.objects` marks such a field. The
  -- parent type (one entry below, unwrapped from NonNull/List) is named in
  -- the error.
  function rules.fieldsDefinedOnType(node, context)
    local stack = context.objects
    if stack[#stack] ~= false then
      return
    end

    local owner = stack[#stack - 1]
    while owner.ofType do
      owner = owner.ofType
    end
    error('Field "' .. node.name.value .. '" is not defined on type "' .. owner.name .. '"')
  end
  --- Ensure every argument passed to a field is declared by that field.
  function rules.argumentsDefinedOnType(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end

    local fieldDefinition = getParentField(context, node.name.value)
    local declared = fieldDefinition.arguments
    for _, argument in pairs(supplied) do
      local argumentName = argument.name.value
      if not declared[argumentName] then
        error('Non-existent argument "' .. argumentName .. '"')
      end
    end
  end
  --- Leaf-field rule: a field whose named type is a Scalar or Enum must not
  -- have a selection set.
  -- The field's type is read from the top of `context.objects`. NonNull/List
  -- wrappers are removed first, so `String!` and `[String]` are checked too.
  -- A falsy entry (a field missing from its parent type) is skipped and left
  -- for fieldsDefinedOnType to report.
  function rules.scalarFieldsAreLeaves(node, context)
    local fieldType = context.objects[#context.objects]
    if not fieldType or not node.selectionSet then
      return
    end

    while fieldType.ofType do
      fieldType = fieldType.ofType
    end

    if fieldType.__type == 'Scalar' then
      error('Scalar values cannot have subselections')
    elseif fieldType.__type == 'Enum' then
      error('Enum values cannot have subselections')
    end
  end
  --- A field whose named type is an Object, Interface, or Union must have a
  -- selection set.
  -- The field's type is read from the top of `context.objects`. NonNull/List
  -- wrappers are removed first, so `[User]!` is treated like `User`. A falsy
  -- entry (a field missing from its parent type) is skipped and left for
  -- fieldsDefinedOnType to report.
  function rules.compositeFieldsAreNotLeaves(node, context)
    local fieldType = context.objects[#context.objects]
    if not fieldType then
      return
    end

    while fieldType.ofType do
      fieldType = fieldType.ofType
    end

    local category = fieldType.__type
    local isCompositeType = category == 'Object' or category == 'Interface' or category == 'Union'
    if isCompositeType and not node.selectionSet then
      error('Composite types must have subselections')
    end
  end
  --- Structural equality of two type objects.
  -- Identical tables are equal. NonNull/List wrappers are equal when their
  -- inner types are structurally equal.
  local function sameTypeShape(a, b)
    if a == b then
      return true
    end
    if type(a) ~= 'table' or type(b) ~= 'table' then
      return false
    end
    local wrapper = a.__type
    if (wrapper == 'NonNull' or wrapper == 'List') and b.__type == wrapper then
      return sameTypeShape(a.ofType, b.ofType)
    end
    return false
  end

  --- Structural equality of two argument value AST nodes.
  -- Variables compare by name, lists element-wise, input objects field by
  -- field (ignoring field order), and all other kinds by their `value`.
  -- NOTE(review): assumes list literals keep their items in `values`, like
  -- input objects do.
  local function argumentValuesEqual(a, b)
    if a.kind ~= b.kind then
      return false
    end

    if a.kind == 'variable' then
      return a.name.value == b.name.value
    elseif a.kind == 'list' then
      local itemsA, itemsB = a.values or {}, b.values or {}
      if #itemsA ~= #itemsB then
        return false
      end
      for i = 1, #itemsA do
        if not argumentValuesEqual(itemsA[i], itemsB[i]) then
          return false
        end
      end
      return true
    elseif a.kind == 'inputObject' then
      local fieldsA, fieldsB = a.values or {}, b.values or {}
      if #fieldsA ~= #fieldsB then
        return false
      end
      local byName = {}
      for _, field in ipairs(fieldsA) do
        byName[field.name] = field.value
      end
      for _, field in ipairs(fieldsB) do
        local other = byName[field.name]
        if other == nil or not argumentValuesEqual(other, field.value) then
          return false
        end
      end
      return true
    end

    return a.value == b.value
  end

  --- Field selection merging rule.
  -- Selections that share a response key (alias or field name) must refer to
  -- the same field, return the same type, and use identical arguments. The
  -- exception is selections on two different Object types, which can never
  -- apply to the same result object. Inline fragments and fragment spreads
  -- are followed (each fragment once).
  -- Raises 'Type name mismatch', 'Return type mismatch' or 'Argument mismatch'.
  function rules.unambiguousSelections(node, context)
    local selectionMap = {}
    local seen = {}

    -- Returns an error message if the two entries cannot be merged, else nil.
    local function findConflict(entryA, entryB)
      -- Parent types can't overlap if they're different objects.
      -- Interface and union types may overlap.
      if entryA.parent ~= entryB.parent
        and entryA.parent.__type == 'Object'
        and entryB.parent.__type == 'Object' then
        return nil
      end

      -- Aliases that map two different fields to the same name.
      if entryA.field.name.value ~= entryB.field.name.value then
        return 'Type name mismatch'
      end

      -- Same field name but different return types.
      if entryA.definition and entryB.definition
        and not sameTypeShape(entryA.definition, entryB.definition) then
        return 'Return type mismatch'
      end

      -- Arguments must have the same names and structurally equal values.
      local argsA = entryA.field.arguments or {}
      local argsB = entryB.field.arguments or {}
      if #argsA ~= #argsB then
        return 'Argument mismatch'
      end
      local valuesByName = {}
      for i = 1, #argsA do
        valuesByName[argsA[i].name.value] = argsA[i].value
      end
      for i = 1, #argsB do
        local otherValue = valuesByName[argsB[i].name.value]
        if not otherValue or not argumentValuesEqual(otherValue, argsB[i].value) then
          return 'Argument mismatch'
        end
      end
      return nil
    end

    -- Compare `entry` with every earlier entry under the same key, then record it.
    local function validateField(key, entry)
      local existing = selectionMap[key]
      if not existing then
        selectionMap[key] = { entry }
        return
      end
      for i = 1, #existing do
        local conflict = findConflict(existing[i], entry)
        if conflict then
          error(conflict)
        end
      end
      existing[#existing + 1] = entry
    end

    -- Recursively make sure that there are no ambiguous selections with the same name.
    local function validateSelectionSet(selectionSet, parentType)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'field' then
          local fieldDefinition = parentType and parentType.fields
            and parentType.fields[selection.name.value]
          -- Fields unknown to parentType are skipped one at a time (other
          -- rules report them) instead of aborting the whole selection set.
          if fieldDefinition then
            local key = selection.alias and selection.alias.name.value or selection.name.value
            validateField(key, {
              parent = parentType,
              field = selection,
              definition = fieldDefinition.kind
            })
          end
        elseif selection.kind == 'inlineFragment' then
          local fragmentType = selection.typeCondition
            and context.schema:getType(selection.typeCondition.name.value)
            or parentType
          validateSelectionSet(selection.selectionSet, fragmentType)
        elseif selection.kind == 'fragmentSpread' then
          local fragmentDefinition = context.fragmentMap[selection.name.value]
          if fragmentDefinition and not seen[fragmentDefinition] then
            seen[fragmentDefinition] = true
            if fragmentDefinition.typeCondition then
              local fragmentType = context.schema:getType(fragmentDefinition.typeCondition.name.value)
              validateSelectionSet(fragmentDefinition.selectionSet, fragmentType)
            end
          end
        end
      end
    end

    -- The enclosing type may be wrapped in NonNull/List. Look through the
    -- wrappers so its fields can be resolved. A falsy entry (undefined field)
    -- is passed on unchanged and turns off field checks.
    local enclosingType = context.objects[#context.objects]
    while type(enclosingType) == 'table' and enclosingType.ofType do
      enclosingType = enclosingType.ofType
    end
    validateSelectionSet(node, enclosingType)
  end
  --- Reject a node that passes the same argument name more than once.
  function rules.uniqueArgumentNames(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end

    local seenNames = {}
    for _, argument in ipairs(supplied) do
      local argumentName = argument.name.value
      if seenNames[argumentName] then
        error('Encountered multiple arguments named "' .. argumentName .. '"')
      end
      seenNames[argumentName] = true
    end
  end
  --- Check each supplied argument value against its declared type by passing
  -- both to `util.coerceValue`, which is expected to raise on a mismatch.
  -- A declared argument may be a bare type or a table carrying that type in
  -- `kind`. Both forms are accepted.
  function rules.argumentsOfCorrectType(node, context)
    local supplied = node.arguments
    if not supplied then
      return
    end

    local fieldDefinition = getParentField(context, node.name.value)
    for _, argument in pairs(supplied) do
      local declared = fieldDefinition.arguments[argument.name.value]
      local expectedType = declared.kind or declared
      util.coerceValue(argument.value, expectedType)
    end
  end
  --- Ensure every NonNull argument declared by the field is supplied.
  -- A declared argument may be a bare type or a table carrying that type in
  -- `kind` (the same convention argumentsOfCorrectType handles). Both forms
  -- are unwrapped before the NonNull check. An argument declared with a
  -- `defaultValue` is treated as optional (NOTE(review): assumes defaults are
  -- stored under `defaultValue`; confirm against types.lua).
  -- Fields that cannot be resolved are skipped; fieldsDefinedOnType reports them.
  function rules.requiredArgumentsPresent(node, context)
    local fieldDefinition = getParentField(context, node.name.value)
    if not fieldDefinition then
      return
    end

    -- Names of the arguments written in the query.
    local supplied = {}
    for _, argument in ipairs(node.arguments or {}) do
      supplied[argument.name.value] = true
    end

    for name, declared in pairs(fieldDefinition.arguments or {}) do
      local argumentType = declared.kind or declared
      if argumentType.__type == 'NonNull'
        and declared.defaultValue == nil
        and not supplied[name] then
        error('Required argument "' .. name .. '" was not supplied.')
      end
    end
  end
  --- Reject a document that defines two fragments with the same name.
  function rules.uniqueFragmentNames(node, context)
    local fragmentNames = {}
    for _, definition in ipairs(node.definitions) do
      local isFragment = definition.kind == 'fragmentDefinition'
      if isFragment then
        local fragmentName = definition.name.value
        if fragmentNames[fragmentName] then
          error('Encountered multiple fragments named "' .. fragmentName .. '"')
        end
        fragmentNames[fragmentName] = true
      end
    end
  end
  --- A fragment's type condition must name an existing Object, Interface,
  -- or Union type. Fragments without a type condition are accepted.
  function rules.fragmentHasValidType(node, context)
    local condition = node.typeCondition
    if not condition then
      return
    end

    local typeName = condition.name.value
    local fragmentType = context.schema:getType(typeName)
    if not fragmentType then
      error('Fragment refers to non-existent type "' .. typeName .. '"')
    end

    local category = fragmentType.__type
    local isComposite = category == 'Object' or category == 'Interface' or category == 'Union'
    if not isComposite then
      error('Fragment type must be an Object, Interface, or Union, got ' .. category)
    end
  end
  --- Every fragment defined in the document must appear in
  -- `context.usedFragments`.
  function rules.noUnusedFragments(node, context)
    local used = context.usedFragments
    for _, definition in ipairs(node.definitions) do
      if definition.kind == 'fragmentDefinition' and not used[definition.name.value] then
        error('Fragment "' .. definition.name.value .. '" was not used.')
      end
    end
  end
  --- A fragment spread must name a fragment present in `context.fragmentMap`.
  function rules.fragmentSpreadTargetDefined(node, context)
    local fragmentName = node.name.value
    local target = context.fragmentMap[fragmentName]
    if not target then
      error('Fragment spread refers to non-existent fragment "' .. fragmentName .. '"')
    end
  end
  --- Reject a fragment definition whose spreads lead back to a fragment that
  -- is still being expanded (a cycle).
  -- Spreads are followed through inline fragments, nested field selection
  -- sets, and the definitions of spread fragments.
  -- `onPath` holds the fragments on the current expansion path. `finished`
  -- holds fragments already fully expanded without finding a cycle. Because
  -- of this, a fragment reachable along several paths (a diamond) or spread
  -- twice is expanded once and is not mistaken for a cycle.
  function rules.fragmentDefinitionHasNoCycles(node, context)
    local onPath = { [node.name.value] = true }
    local finished = {}

    local function detectCycles(selectionSet)
      for _, selection in ipairs(selectionSet.selections) do
        if selection.kind == 'inlineFragment' then
          detectCycles(selection.selectionSet)
        elseif selection.kind == 'field' then
          -- Spreads inside a field's sub-selection also take part in cycles.
          if selection.selectionSet then
            detectCycles(selection.selectionSet)
          end
        elseif selection.kind == 'fragmentSpread' then
          local name = selection.name.value
          if onPath[name] then
            error('Fragment definition has cycles')
          end
          if not finished[name] then
            local fragmentDefinition = context.fragmentMap[name]
            if fragmentDefinition and fragmentDefinition.typeCondition then
              onPath[name] = true
              detectCycles(fragmentDefinition.selectionSet)
              onPath[name] = nil
            end
            finished[name] = true
          end
        end
      end
    end

    detectCycles(node.selectionSet)
  end
  --- A fragment (inline or spread) may only be used where its type can
  -- overlap the enclosing type.
  -- Both types are expanded to their sets of possible object types:
  -- - an Object is itself;
  -- - an Interface is its implementors, from `schema:getImplementors`;
  -- - a Union is its members.
  -- The two sets must share at least one type.
  -- Unknown fragments and types missing from the schema are skipped and left
  -- to other rules.
  function rules.fragmentSpreadIsPossible(node, context)
    local isInline = node.kind == 'inlineFragment'
    local fragment = isInline and node or context.fragmentMap[node.name.value]
    if not fragment then
      return
    end

    -- Type of the enclosing selection, with NonNull/List wrappers removed.
    -- The nil/false guard runs before indexing.
    local parentType = context.objects[#context.objects - 1]
    while parentType and parentType.ofType do
      parentType = parentType.ofType
    end

    local fragmentType
    if fragment.typeCondition then
      fragmentType = context.schema:getType(fragment.typeCondition.name.value)
    end
    if isInline and not fragmentType then
      -- An inline fragment with no type condition, or an unresolved one,
      -- applies to the parent type.
      fragmentType = parentType
    end

    -- Some types are not present in the schema. Let other rules handle this.
    if not parentType or not fragmentType then
      return
    end

    local function possibleTypes(kind)
      if kind.__type == 'Object' then
        return { [kind] = kind }
      elseif kind.__type == 'Interface' then
        return context.schema:getImplementors(kind.name) or {}
      elseif kind.__type == 'Union' then
        local members = {}
        for i = 1, #kind.types do
          members[kind.types[i]] = kind.types[i]
        end
        return members
      end
      return {}
    end

    local parentTypes = possibleTypes(parentType)
    local fragmentTypes = possibleTypes(fragmentType)
    local valid = util.find(parentTypes, function(kind)
      return fragmentTypes[kind]
    end)
    if not valid then
      error('Fragment type condition is not possible for given type')
    end
  end
  --- An input object literal must not repeat a field name.
  -- The check walks `node.value` recursively:
  -- - into the values of input object fields;
  -- - into the items of list literals, so `[{a: 1, a: 2}]` is caught.
  -- NOTE(review): assumes list literal nodes keep their items in `values`.
  function rules.uniqueInputObjectFields(node, context)
    local function validateValue(value)
      if not value then
        return
      end

      local kind = value.kind
      if kind == 'listType' or kind == 'nonNullType' then
        return validateValue(value.type)
      elseif kind == 'list' then
        for _, item in ipairs(value.values or {}) do
          validateValue(item)
        end
      elseif kind == 'inputObject' then
        local fieldMap = {}
        for _, field in ipairs(value.values) do
          if fieldMap[field.name] then
            error('Multiple input object fields named "' .. field.name .. '"')
          end
          fieldMap[field.name] = true
          validateValue(field.value)
        end
      end
    end

    validateValue(node.value)
  end
  --- Every directive applied to the node must be known to the schema.
  function rules.directivesAreDefined(node, context)
    local directives = node.directives
    if not directives then
      return
    end

    local schema = context.schema
    for _, directive in pairs(directives) do
      local directiveName = directive.name.value
      if not schema:getDirective(directiveName) then
        error('Unknown directive "' .. directiveName .. '"')
      end
    end
  end
  --- Each variable's named type (after removing list/non-null wrappers) must
  -- exist in the schema and be a Scalar, Enum, or InputObject.
  function rules.variablesHaveCorrectType(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end

    for _, definition in ipairs(definitions) do
      local typeNode = definition.type
      while typeNode.kind == 'listType' or typeNode.kind == 'nonNullType' do
        typeNode = typeNode.type
      end

      if typeNode.kind == 'namedType' then
        local typeName = typeNode.name.value
        local schemaType = context.schema:getType(typeName)
        if not schemaType then
          error('Variable specifies unknown type "' .. tostring(typeName) .. '"')
        end
        local category = schemaType.__type
        if category ~= 'Scalar' and category ~= 'Enum' and category ~= 'InputObject' then
          error('Variable types must be scalars, enums, or input objects, got "' .. category .. '"')
        end
      end
    end
  end
  --- Validate variable default values.
  -- A variable with a NonNull type must not declare a default. Any other
  -- default is passed to `util.coerceValue` together with the variable's full
  -- type, rebuilt from its AST including list wrappers. Previously a
  -- list-typed variable with a default crashed on `type.name`. Variables
  -- whose named type is unknown are skipped (variablesHaveCorrectType
  -- reports them).
  function rules.variableDefaultValuesHaveCorrectType(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end

    -- Schema type described by a variable type AST node, or nil if unknown.
    local function resolveType(typeNode)
      if typeNode.kind == 'listType' then
        local inner = resolveType(typeNode.type)
        return inner and types.list(inner)
      elseif typeNode.kind == 'nonNullType' then
        local inner = resolveType(typeNode.type)
        return inner and types.nonNull(inner)
      end
      return context.schema:getType(typeNode.name.value)
    end

    for _, definition in ipairs(definitions) do
      if definition.defaultValue then
        if definition.type.kind == 'nonNullType' then
          error('Non-null variables can not have default values')
        end
        local variableType = resolveType(definition.type)
        if variableType then
          util.coerceValue(definition.defaultValue, variableType)
        end
      end
    end
  end
  --- Every variable declared by the operation must appear in
  -- `context.variableReferences`.
  function rules.variablesAreUsed(node, context)
    local definitions = node.variableDefinitions
    if not definitions then
      return
    end

    for _, definition in ipairs(definitions) do
      local variableName = definition.variable.name.value
      local referenced = context.variableReferences[variableName]
      if not referenced then
        error('Unused variable "' .. variableName .. '"')
      end
    end
  end
  --- Every variable referenced (per `context.variableReferences`) must be
  -- declared in the operation's variable definitions.
  function rules.variablesAreDefined(node, context)
    local references = context.variableReferences
    if not references then
      return
    end

    local declared = {}
    for _, definition in ipairs(node.variableDefinitions or {}) do
      declared[definition.variable.name.value] = true
    end

    for variableName in pairs(references) do
      if not declared[variableName] then
        error('Unknown variable "' .. variableName .. '"')
      end
    end
  end
  --- Build the schema type described by a variable's type AST node.
  -- List and non-null wrappers are rebuilt with `types.list`/`types.nonNull`.
  -- Returns nil when the named type is not in `schema`.
  local function variableTypeFromAST(schema, typeNode)
    if typeNode.kind == 'listType' then
      local innerType = variableTypeFromAST(schema, typeNode.type)
      return innerType and types.list(innerType)
    elseif typeNode.kind == 'nonNullType' then
      local innerType = variableTypeFromAST(schema, typeNode.type)
      return innerType and types.nonNull(innerType)
    end
    assert(typeNode.kind == 'namedType', 'Variable must be a named type')
    return schema:getType(typeNode.name.value)
  end

  --- Whether a value of `subType` may be used where `superType` is expected.
  -- Allowed cases:
  -- - identical types;
  -- - a NonNull subtype where a nullable type is expected;
  -- - element-wise compatible Lists;
  -- - an Object that implements the expected Interface or belongs to the
  --   expected Union.
  -- Returns a truthy value when allowed.
  local function isTypeSubTypeOf(schema, subType, superType)
    if subType == superType then
      return true
    end

    if superType.__type == 'NonNull' then
      if subType.__type == 'NonNull' then
        return isTypeSubTypeOf(schema, subType.ofType, superType.ofType)
      end
      return false
    elseif subType.__type == 'NonNull' then
      return isTypeSubTypeOf(schema, subType.ofType, superType)
    end

    if superType.__type == 'List' then
      if subType.__type == 'List' then
        return isTypeSubTypeOf(schema, subType.ofType, superType.ofType)
      end
      return false
    elseif subType.__type == 'List' then
      return false
    end

    if subType.__type ~= 'Object' then
      return false
    end

    if superType.__type == 'Interface' then
      local implementors = schema:getImplementors(superType.name)
      return implementors and implementors[schema:getType(subType.name)]
    elseif superType.__type == 'Union' then
      local members = superType.types
      for i = 1, #members do
        if members[i] == subType then
          return true
        end
      end
    end
    return false
  end

  --- Collect, grouped by field name, the arguments of every field reachable
  -- from `selectionSet`.
  -- Inline fragments and fragment spreads are followed; field sub-selections
  -- are not. Each selection node is visited at most once, which also stops
  -- recursion through cyclic fragments.
  local function collectFragmentArguments(context, selectionSet)
    local collected = {}
    local seen = {}

    local function visit(referencedNode)
      if referencedNode.kind == 'selectionSet' then
        for _, selection in ipairs(referencedNode.selections) do
          if not seen[selection] then
            seen[selection] = true
            visit(selection)
          end
        end
      elseif referencedNode.kind == 'field' and referencedNode.arguments then
        local fieldName = referencedNode.name.value
        local bucket = collected[fieldName] or {}
        collected[fieldName] = bucket
        for _, argument in ipairs(referencedNode.arguments) do
          bucket[#bucket + 1] = argument
        end
      elseif referencedNode.kind == 'inlineFragment' then
        visit(referencedNode.selectionSet)
      elseif referencedNode.kind == 'fragmentSpread' then
        local fragment = context.fragmentMap[referencedNode.name.value]
        if fragment then
          visit(fragment.selectionSet)
        end
      end
    end

    visit(selectionSet)
    return collected
  end

  --- Each variable passed as an argument must have a type compatible with
  -- the argument's declared type.
  -- A variable with a default value is treated as NonNull. Declared argument
  -- types may be bare types or `{ kind = T }` tables; both are unwrapped, the
  -- same way argumentsOfCorrectType does. Unknown fields, arguments,
  -- variables, and variable types are skipped and left to the rules that
  -- report them, rather than crashing here.
  function rules.variableUsageAllowed(node, context)
    if not context.currentOperation then
      return
    end

    local variableMap = {}
    for _, definition in ipairs(context.currentOperation.variableDefinitions or {}) do
      variableMap[definition.variable.name.value] = definition
    end

    local arguments
    if node.kind == 'field' then
      arguments = { [node.name.value] = node.arguments }
    elseif node.kind == 'fragmentSpread' then
      local fragment = context.fragmentMap[node.name.value]
      if fragment then
        arguments = collectFragmentArguments(context, fragment.selectionSet)
      end
    end
    if not arguments then
      return
    end

    for fieldName, fieldArguments in pairs(arguments) do
      local parentField = getParentField(context, fieldName)
      if parentField then
        for i = 1, #fieldArguments do
          local argument = fieldArguments[i]
          if argument.value.kind == 'variable' then
            local declared = parentField.arguments and parentField.arguments[argument.name.value]
            local variableDefinition = variableMap[argument.value.name.value]
            if declared and variableDefinition then
              local argumentType = declared.kind or declared
              local variableType = variableTypeFromAST(context.schema, variableDefinition.type)
              if variableType then
                if variableDefinition.defaultValue ~= nil and variableType.__type ~= 'NonNull' then
                  variableType = types.nonNull(variableType)
                end
                if not isTypeSubTypeOf(context.schema, variableType, argumentType) then
                  error('Variable type mismatch')
                end
              end
            end
          end
        end
      end
    end
  end

  return rules