Browse Source

Merge pull request #10 from bjornbytes/introspection

Introspection
Bjorn Swenson 9 years ago
parent
commit
3b011a0198
8 changed files with 674 additions and 89 deletions
  1. 7 7
      graphql/execute.lua
  2. 475 0
      graphql/introspection.lua
  3. 1 1
      graphql/parse.lua
  4. 26 8
      graphql/rules.lua
  5. 64 43
      graphql/schema.lua
  6. 41 13
      graphql/types.lua
  7. 20 8
      graphql/util.lua
  8. 40 9
      graphql/validate.lua

+ 7 - 7
graphql/execute.lua

@@ -1,6 +1,7 @@
 local path = (...):gsub('%.[^%.]+$', '')
 local path = (...):gsub('%.[^%.]+$', '')
 local types = require(path .. '.types')
 local types = require(path .. '.types')
 local util = require(path .. '.util')
 local util = require(path .. '.util')
+local introspection = require(path .. '.introspection')
 
 
 local function typeFromAST(node, schema)
 local function typeFromAST(node, schema)
   local innerType
   local innerType
@@ -149,10 +150,10 @@ local function completeValue(fieldType, result, subSelections, context)
 
 
   if fieldTypeName == 'NonNull' then
   if fieldTypeName == 'NonNull' then
     local innerType = fieldType.ofType
     local innerType = fieldType.ofType
-    local completedResult = completeValue(innerType, result, context)
+    local completedResult = completeValue(innerType, result, subSelections, context)
 
 
     if completedResult == nil then
     if completedResult == nil then
-      error('No value provided for non-null ' .. innerType.name)
+      error('No value provided for non-null ' .. (innerType.name or innerType.__type))
     end
     end
 
 
     return completedResult
     return completedResult
@@ -195,7 +196,7 @@ local function getFieldEntry(objectType, object, fields, context)
   local firstField = fields[1]
   local firstField = fields[1]
   local fieldName = firstField.name.value
   local fieldName = firstField.name.value
   local responseKey = getFieldResponseKey(firstField)
   local responseKey = getFieldResponseKey(firstField)
-  local fieldType = objectType.fields[fieldName]
+  local fieldType = introspection.fieldMap[fieldName] or objectType.fields[fieldName]
 
 
   if fieldType == nil then
   if fieldType == nil then
     return nil
     return nil
@@ -206,7 +207,7 @@ local function getFieldEntry(objectType, object, fields, context)
     argumentMap[argument.name.value] = argument
     argumentMap[argument.name.value] = argument
   end
   end
 
 
-  local arguments = util.map(fieldType.arguments, function(argument, name)
+  local arguments = util.map(fieldType.arguments or {}, function(argument, name)
     local supplied = argumentMap[name] and argumentMap[name].value
     local supplied = argumentMap[name] and argumentMap[name].value
     return supplied and util.coerceValue(supplied, argument, context.variables) or argument.defaultValue
     return supplied and util.coerceValue(supplied, argument, context.variables) or argument.defaultValue
   end)
   end)
@@ -224,10 +225,9 @@ local function getFieldEntry(objectType, object, fields, context)
   }
   }
 
 
   local resolvedObject = (fieldType.resolve or defaultResolver)(object, arguments, info)
   local resolvedObject = (fieldType.resolve or defaultResolver)(object, arguments, info)
-
   local subSelections = mergeSelectionSets(fields)
   local subSelections = mergeSelectionSets(fields)
-  local responseValue = completeValue(fieldType.kind, resolvedObject, subSelections, context)
-  return responseValue
+
+  return completeValue(fieldType.kind, resolvedObject, subSelections, context)
 end
 end
 
 
 evaluateSelections = function(objectType, object, selections, context)
 evaluateSelections = function(objectType, object, selections, context)

+ 475 - 0
graphql/introspection.lua

@@ -0,0 +1,475 @@
+local path = (...):gsub('%.[^%.]+$', '')
+local types = require(path .. '.types')
+local util = require(path .. '.util')
+
+local __Schema, __Directive, __DirectiveLocation, __Type, __Field, __InputValue,__EnumValue, __TypeKind
+
+__Schema = types.object({
+  name = '__Schema',
+
+  description = util.trim [[
+    A GraphQL Schema defines the capabilities of a GraphQL server. It exposes all available types
+    and directives on the server, as well as the entry points for query and mutation operations.
+  ]],
+
+  fields = function()
+    return {
+      types = {
+        description = 'A list of all types supported by this server.',
+        kind = types.nonNull(types.list(types.nonNull(__Type))),
+        resolve = function(schema)
+          return util.values(schema:getTypeMap())
+        end
+      },
+
+      queryType = {
+        description = 'The type that query operations will be rooted at.',
+        kind = __Type.nonNull,
+        resolve = function(schema)
+          return schema:getQueryType()
+        end
+      },
+
+      mutationType = {
+        description = 'If this server supports mutation, the type that mutation operations will be rooted at.',
+        kind = __Type,
+        resolve = function(schema)
+          return schema:getMutationType()
+        end
+      },
+
+      directives = {
+        description = 'A list of all directives supported by this server.',
+        kind = types.nonNull(types.list(types.nonNull(__Directive))),
+        resolve = function(schema)
+          return schema.directives
+        end
+      }
+    }
+  end
+})
+
+__Directive = types.object({
+  name = '__Directive',
+
+  description = util.trim [[
+    A Directive provides a way to describe alternate runtime execution and type validation behavior
+    in a GraphQL document.
+
+    In some cases, you need to provide options to alter GraphQL’s execution
+    behavior in ways field arguments will not suffice, such as conditionally including or skipping a
+    field. Directives provide this by describing additional information to the executor.
+  ]],
+
+  fields = function()
+    return {
+      name = types.nonNull(types.string),
+
+      description = types.string,
+
+      locations = {
+        kind = types.nonNull(types.list(types.nonNull(
+          __DirectiveLocation
+        ))),
+        resolve = function(directive)
+          local res = {}
+
+          if directive.onQuery then table.insert(res, 'QUERY') end
+          if directive.onMutation then table.insert(res, 'MUTATION') end
+          if directive.onField then table.insert(res, 'FIELD') end
+          if directive.onFragmentDefinition then table.insert(res, 'FRAGMENT_DEFINITION') end
+          if directive.onFragmentSpread then table.insert(res, 'FRAGMENT_SPREAD') end
+          if directive.onInlineFragment then table.insert(res, 'INLINE_FRAGMENT') end
+
+          return res
+        end
+      },
+
+      args = {
+        kind = types.nonNull(types.list(types.nonNull(__InputValue))),
+        resolve = function(field)
+          local args = {}
+          local transform = function(a, n)
+            if a.__type then
+              return { kind = a, name = n }
+            else
+              if a.name then return a end
+
+              local r = { name = n }
+              for k,v in pairs(a) do
+                r[k] = v
+              end
+
+              return r
+            end
+          end
+
+          for k, v in pairs(field.arguments or {}) do
+            table.insert(args, transform(v, k))
+          end
+
+          return args
+        end
+      }
+    }
+  end
+})
+
+__DirectiveLocation = types.enum({
+  name = '__DirectiveLocation',
+
+  description = util.trim [[
+    A Directive can be adjacent to many parts of the GraphQL language, a __DirectiveLocation
+    describes one such possible adjacencies.
+  ]],
+
+  values = {
+    QUERY = {
+      value = 'QUERY',
+      description = 'Location adjacent to a query operation.'
+    },
+
+    MUTATION = {
+      value = 'MUTATION',
+      description = 'Location adjacent to a mutation operation.'
+    },
+
+    FIELD = {
+      value = 'FIELD',
+      description = 'Location adjacent to a field.'
+    },
+
+    FRAGMENT_DEFINITION = {
+      value = 'FRAGMENT_DEFINITION',
+      description = 'Location adjacent to a fragment definition.'
+    },
+
+    FRAGMENT_SPREAD = {
+      value = 'FRAGMENT_SPREAD',
+      description = 'Location adjacent to a fragment spread.'
+    },
+
+    INLINE_FRAGMENT = {
+      value = 'INLINE_FRAGMENT',
+      description = 'Location adjacent to an inline fragment.'
+    }
+  }
+})
+
+__Type = types.object({
+  name = '__Type',
+
+  description = util.trim [[
+    The fundamental unit of any GraphQL Schema is the type. There are
+    many kinds of types in GraphQL as represented by the `__TypeKind` enum.
+
+    Depending on the kind of a type, certain fields describe
+    information about that type. Scalar types provide no information
+    beyond a name and description, while Enum types provide their values.
+    Object and Interface types provide the fields they describe. Abstract
+    types, Union and Interface, provide the Object types possible
+    at runtime. List and NonNull types compose other types.
+  ]],
+
+  fields = function()
+    return {
+      name = types.string,
+      description = types.string,
+
+      kind = {
+        kind = __TypeKind.nonNull,
+        resolve = function(kind)
+          if kind.__type == 'Scalar' then
+            return 'SCALAR'
+          elseif kind.__type == 'Object' then
+            return 'OBJECT'
+          elseif kind.__type == 'Interface' then
+            return 'INTERFACE'
+          elseif kind.__type == 'Union' then
+            return 'UNION'
+          elseif kind.__type == 'Enum' then
+            return 'ENUM'
+          elseif kind.__type == 'InputObject' then
+            return 'INPUT_OBJECT'
+          elseif kind.__type == 'List' then
+            return 'LIST'
+          elseif kind.__type == 'NonNull' then
+            return 'NON_NULL'
+          end
+
+          error('Unknown type ' .. kind)
+        end
+      },
+
+      fields = {
+        kind = types.list(types.nonNull(__Field)),
+        arguments = {
+          includeDeprecated = {
+            kind = types.boolean,
+            defaultValue = false
+          }
+        },
+        resolve = function(kind, arguments)
+          if kind.__type == 'Object' or kind.__type == 'Interface' then
+            return util.filter(util.values(kind.fields), function(field)
+              return arguments.includeDeprecated or field.deprecationReason == nil
+            end)
+          end
+
+          return nil
+        end
+      },
+
+      interfaces = {
+        kind = types.list(types.nonNull(__Type)),
+        resolve = function(kind)
+          if kind.__type == 'Object' then
+            return kind.interfaces
+          end
+        end
+      },
+
+      possibleTypes = {
+        kind = types.list(types.nonNull(__Type)),
+        resolve = function(kind, arguments, context)
+          if kind.__type == 'Interface' or kind.__type == 'Union' then
+            return context.schema:getPossibleTypes(kind)
+          end
+        end
+      },
+
+      enumValues = {
+        kind = types.list(types.nonNull(__EnumValue)),
+        arguments = {
+          includeDeprecated = { kind = types.boolean, defaultValue = false }
+        },
+        resolve = function(kind, arguments)
+          if kind.__type == 'Enum' then
+            return util.filter(util.values(kind.values), function(value)
+              return arguments.includeDeprecated or not value.deprecationReason
+            end)
+          end
+        end
+      },
+
+      inputFields = {
+        kind = types.list(types.nonNull(__InputValue)),
+        resolve = function(kind)
+          if kind.__type == 'InputObject' then
+            return util.values(kind.fields)
+          end
+        end
+      },
+
+      ofType = {
+        kind = __Type
+      }
+    }
+  end
+})
+
+__Field = types.object({
+  name = '__Field',
+
+  description = util.trim [[
+    Object and Interface types are described by a list of Fields, each of
+    which has a name, potentially a list of arguments, and a return type.
+  ]],
+
+  fields = function()
+    return {
+      name = types.string.nonNull,
+      description = types.string,
+
+      args = {
+        -- kind = types.list(__InputValue),
+        kind = types.nonNull(types.list(types.nonNull(__InputValue))),
+        resolve = function(field)
+          return util.map(field.arguments or {}, function(a, n)
+            if a.__type then
+              return { kind = a, name = n }
+            else
+              if not a.name then
+                local r = { name = n }
+
+                for k,v in pairs(a) do
+                  r[k] = v
+                end
+
+                return r
+              else
+                return a
+              end
+            end
+          end)
+        end
+      },
+
+      type = {
+        kind = __Type.nonNull,
+        resolve = function(field)
+          return field.kind
+        end
+      },
+
+      isDeprecated = {
+        kind = types.boolean.nonNull,
+        resolve = function(field)
+          return field.deprecationReason ~= nil
+        end
+      },
+
+      deprecationReason = types.string
+    }
+  end
+})
+
+__InputValue = types.object({
+  name = '__InputValue',
+
+  description = util.trim [[
+    Arguments provided to Fields or Directives and the input fields of an
+    InputObject are represented as Input Values which describe their type
+    and optionally a default value.
+  ]],
+
+  fields = function()
+    return {
+      name = types.string.nonNull,
+      description = types.string,
+
+      type = {
+        kind = types.nonNull(__Type),
+        resolve = function(field)
+          return field.kind
+        end
+      },
+
+      defaultValue = {
+        kind = types.string,
+        description = 'A GraphQL-formatted string representing the default value for this input value.',
+        resolve = function(inputVal)
+          return inputVal.defaultValue and tostring(inputVal.defaultValue) -- TODO improve serialization a lot
+        end
+      }
+    }
+  end
+})
+
+__EnumValue = types.object({
+  name = '__EnumValue',
+
+  description = [[
+    One possible value for a given Enum. Enum values are unique values, not
+    a placeholder for a string or numeric value. However an Enum value is
+    returned in a JSON response as a string.
+  ]],
+
+  fields = function()
+    return {
+      name = types.string.nonNull,
+      description = types.string,
+      isDeprecated = {
+        kind = types.boolean.nonNull,
+        resolve = function(enumValue) return enumValue.deprecationReason ~= nil end
+      },
+      deprecationReason = types.string
+    }
+  end
+})
+
+__TypeKind = types.enum({
+  name = '__TypeKind',
+  description = 'An enum describing what kind of type a given `__Type` is.',
+  values = {
+    SCALAR = {
+      value = 'SCALAR',
+      description = 'Indicates this type is a scalar.'
+    },
+
+    OBJECT = {
+      value = 'OBJECT',
+      description = 'Indicates this type is an object. `fields` and `interfaces` are valid fields.'
+    },
+
+    INTERFACE = {
+      value = 'INTERFACE',
+      description = 'Indicates this type is an interface. `fields` and `possibleTypes` are valid fields.'
+    },
+
+    UNION = {
+      value = 'UNION',
+      description = 'Indicates this type is a union. `possibleTypes` is a valid field.'
+    },
+
+    ENUM = {
+      value = 'ENUM',
+      description = 'Indicates this type is an enum. `enumValues` is a valid field.'
+    },
+
+    INPUT_OBJECT = {
+      value = 'INPUT_OBJECT',
+      description = 'Indicates this type is an input object. `inputFields` is a valid field.'
+    },
+
+    LIST = {
+      value = 'LIST',
+      description = 'Indicates this type is a list. `ofType` is a valid field.'
+    },
+
+    NON_NULL = {
+      value = 'NON_NULL',
+      description = 'Indicates this type is a non-null. `ofType` is a valid field.'
+    }
+  }
+})
+
+local Schema = {
+  name = '__schema',
+  kind = __Schema.nonNull,
+  description = 'Access the current type schema of this server.',
+  arguments = {},
+  resolve = function(_, _, info)
+    return info.schema
+  end
+}
+
+local Type = {
+  name = '__type',
+  kind = __Type,
+  description = 'Request the type information of a single type.',
+  arguments = {
+    name = types.string.nonNull
+  },
+  resolve = function(_, arguments, info)
+    return info.schema:getType(arguments.name)
+  end
+}
+
+local TypeName = {
+  name = '__typename',
+  kind = types.string.nonNull,
+  description = 'The name of the current Object type at runtime.',
+  arguments = {},
+  resolve = function(_, _, info)
+    return info.parentType.name
+  end
+}
+
+return {
+  __Schema = __Schema,
+  __Directive = __Directive,
+  __DirectiveLocation = __DirectiveLocation,
+  __Type = __Type,
+  __Field = __Field,
+  __EnumValue = __EnumValue,
+  __TypeKind = __TypeKind,
+  Schema = Schema,
+  Type = Type,
+  TypeName = TypeName,
+  fieldMap = {
+    __schema = Schema,
+    __type = Type,
+    __typename = TypeName
+  }
+}

+ 1 - 1
graphql/parse.lua

@@ -237,7 +237,7 @@ local function cDirective(name, arguments)
 end
 end
 
 
 -- Simple types
 -- Simple types
-local rawName = R('az', 'AZ') * (P'_' + R'09' + R('az', 'AZ')) ^ 0
+local rawName = (P'_' + R('az', 'AZ')) * (P'_' + R'09' + R('az', 'AZ')) ^ 0
 local name = rawName / cName
 local name = rawName / cName
 local fragmentName = (rawName - ('on' * -rawName)) / cName
 local fragmentName = (rawName - ('on' * -rawName)) / cName
 local alias = ws * name * P':' * ws / cAlias
 local alias = ws * name * P':' * ws / cAlias

+ 26 - 8
graphql/rules.lua

@@ -1,6 +1,22 @@
 local path = (...):gsub('%.[^%.]+$', '')
 local path = (...):gsub('%.[^%.]+$', '')
 local types = require(path .. '.types')
 local types = require(path .. '.types')
 local util = require(path .. '.util')
 local util = require(path .. '.util')
+local schema = require(path .. '.schema')
+local introspection = require(path .. '.introspection')
+
+local function getParentField(context, name, count)
+  if introspection.fieldMap[name] then return introspection.fieldMap[name] end
+
+  count = count or 1
+  local parent = context.objects[#context.objects - count]
+
+  -- Unwrap lists and non-null types
+  while parent.ofType do
+    parent = parent.ofType
+  end
+
+  return parent.fields[name]
+end
 
 
 local rules = {}
 local rules = {}
 
 
@@ -31,16 +47,14 @@ end
 function rules.fieldsDefinedOnType(node, context)
 function rules.fieldsDefinedOnType(node, context)
   if context.objects[#context.objects] == false then
   if context.objects[#context.objects] == false then
     local parent = context.objects[#context.objects - 1]
     local parent = context.objects[#context.objects - 1]
-    if(parent.__type == 'List') then
-      parent = parent.ofType
-    end
+    while parent.ofType do parent = parent.ofType end
     error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
     error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"')
   end
   end
 end
 end
 
 
 function rules.argumentsDefinedOnType(node, context)
 function rules.argumentsDefinedOnType(node, context)
   if node.arguments then
   if node.arguments then
-    local parentField = util.getParentField(context, node.name.value)
+    local parentField = getParentField(context, node.name.value)
     for _, argument in pairs(node.arguments) do
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local name = argument.name.value
       if not parentField.arguments[name] then
       if not parentField.arguments[name] then
@@ -178,18 +192,18 @@ end
 
 
 function rules.argumentsOfCorrectType(node, context)
 function rules.argumentsOfCorrectType(node, context)
   if node.arguments then
   if node.arguments then
-    local parentField = util.getParentField(context, node.name.value)
+    local parentField = getParentField(context, node.name.value)
     for _, argument in pairs(node.arguments) do
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local name = argument.name.value
       local argumentType = parentField.arguments[name]
       local argumentType = parentField.arguments[name]
-      util.coerceValue(argument.value, argumentType)
+      util.coerceValue(argument.value, argumentType.kind or argumentType)
     end
     end
   end
   end
 end
 end
 
 
 function rules.requiredArgumentsPresent(node, context)
 function rules.requiredArgumentsPresent(node, context)
   local arguments = node.arguments or {}
   local arguments = node.arguments or {}
-  local parentField = util.getParentField(context, node.name.value)
+  local parentField = getParentField(context, node.name.value)
   for name, argument in pairs(parentField.arguments) do
   for name, argument in pairs(parentField.arguments) do
     if argument.__type == 'NonNull' then
     if argument.__type == 'NonNull' then
       local present = util.find(arguments, function(argument)
       local present = util.find(arguments, function(argument)
@@ -275,7 +289,9 @@ end
 
 
 function rules.fragmentSpreadIsPossible(node, context)
 function rules.fragmentSpreadIsPossible(node, context)
   local fragment = node.kind == 'inlineFragment' and node or context.fragmentMap[node.name.value]
   local fragment = node.kind == 'inlineFragment' and node or context.fragmentMap[node.name.value]
+
   local parentType = context.objects[#context.objects - 1]
   local parentType = context.objects[#context.objects - 1]
+  while parentType.ofType do parentType = parentType.ofType end
 
 
   local fragmentType
   local fragmentType
   if node.kind == 'inlineFragment' then
   if node.kind == 'inlineFragment' then
@@ -298,6 +314,8 @@ function rules.fragmentSpreadIsPossible(node, context)
         types[kind.types[i]] = kind.types[i]
         types[kind.types[i]] = kind.types[i]
       end
       end
       return types
       return types
+    else
+      return {}
     end
     end
   end
   end
 
 
@@ -448,7 +466,7 @@ function rules.variableUsageAllowed(node, context)
     if not arguments then return end
     if not arguments then return end
 
 
     for field in pairs(arguments) do
     for field in pairs(arguments) do
-      local parentField = util.getParentField(context, field)
+      local parentField = getParentField(context, field)
       for i = 1, #arguments[field] do
       for i = 1, #arguments[field] do
         local argument = arguments[field][i]
         local argument = arguments[field][i]
         if argument.value.kind == 'variable' then
         if argument.value.kind == 'variable' then

+ 64 - 43
graphql/schema.lua

@@ -1,77 +1,78 @@
 local path = (...):gsub('%.[^%.]+$', '')
 local path = (...):gsub('%.[^%.]+$', '')
 local types = require(path .. '.types')
 local types = require(path .. '.types')
+local introspection = require(path .. '.introspection')
 
 
 local schema = {}
 local schema = {}
 schema.__index = schema
 schema.__index = schema
 
 
 function schema.create(config)
 function schema.create(config)
   assert(type(config.query) == 'table', 'must provide query object')
   assert(type(config.query) == 'table', 'must provide query object')
+  assert(not config.mutation or type(config.mutation) == 'table', 'mutation must be a table if provided')
+
+  local self = setmetatable({}, schema)
 
 
-  local self = {}
   for k, v in pairs(config) do
   for k, v in pairs(config) do
     self[k] = v
     self[k] = v
   end
   end
 
 
-  self.typeMap = {
-    Int = types.int,
-    Float = types.float,
-    String = types.string,
-    Boolean = types.boolean,
-    ID = types.id
+  self.directives = self.directives or {
+    types.include,
+    types.skip
   }
   }
 
 
+  self.typeMap = {}
   self.interfaceMap = {}
   self.interfaceMap = {}
   self.directiveMap = {}
   self.directiveMap = {}
 
 
-  local function generateTypeMap(node)
-    if self.typeMap[node.name] and self.typeMap[node.name] == node then return end
+  self:generateTypeMap(self.query)
+  self:generateTypeMap(self.mutation)
+  self:generateTypeMap(introspection.__Schema)
+  self:generateDirectiveMap()
 
 
-    if node.__type == 'NonNull' or node.__type == 'List' then
-      return generateTypeMap(node.ofType)
-    end
+  return self
+end
 
 
-    if self.typeMap[node.name] and self.typeMap[node.name] ~= node then
-      error('Encountered multiple types named "' .. node.name .. '"')
-    end
+function schema:generateTypeMap(node)
+  if not node or (self.typeMap[node.name] and self.typeMap[node.name] == node) then return end
 
 
-    self.typeMap[node.name] = node
+  if node.__type == 'NonNull' or node.__type == 'List' then
+    return self:generateTypeMap(node.ofType)
+  end
 
 
-    if node.__type == 'Object' and node.interfaces then
-      for _, interface in ipairs(node.interfaces) do
-        generateTypeMap(interface)
-        self.interfaceMap[interface.name] = self.interfaceMap[interface.name] or {}
-        self.interfaceMap[interface.name][node] = node
-      end
-    end
+  if self.typeMap[node.name] and self.typeMap[node.name] ~= node then
+    error('Encountered multiple types named "' .. node.name .. '"')
+  end
 
 
-    if node.__type == 'Object' or node.__type == 'Interface' or node.__type == 'InputObject' then
-      if type(node.fields) == 'function' then node.fields = node.fields() end
-      for fieldName, field in pairs(node.fields) do
-        if field.arguments then
-          for _, argument in pairs(field.arguments) do
-            generateTypeMap(argument)
-          end
-        end
+  node.fields = type(node.fields) == 'function' and node.fields() or node.fields
+  self.typeMap[node.name] = node
 
 
-        generateTypeMap(field.kind)
-      end
+  if node.__type == 'Object' and node.interfaces then
+    for _, interface in ipairs(node.interfaces) do
+      self:generateTypeMap(interface)
+      self.interfaceMap[interface.name] = self.interfaceMap[interface.name] or {}
+      self.interfaceMap[interface.name][node] = node
     end
     end
   end
   end
 
 
-  generateTypeMap(self.query)
-
-  self.directives = self.directives or {
-    types.include,
-    types.skip
-  }
+  if node.__type == 'Object' or node.__type == 'Interface' or node.__type == 'InputObject' then
+    for fieldName, field in pairs(node.fields) do
+      if field.arguments then
+        for name, argument in pairs(field.arguments) do
+          local argumentType = argument.__type and argument or argument.kind
+          assert(argumentType, 'Must supply type for argument "' .. name .. '" on "' .. fieldName .. '"')
+          self:generateTypeMap(argumentType)
+        end
+      end
 
 
-  if self.directives then
-    for _, directive in ipairs(self.directives) do
-      self.directiveMap[directive.name] = directive
+      self:generateTypeMap(field.kind)
     end
     end
   end
   end
+end
 
 
-  return setmetatable(self, schema)
+function schema:generateDirectiveMap()
+  for _, directive in ipairs(self.directives) do
+    self.directiveMap[directive.name] = directive
+  end
 end
 end
 
 
 function schema:getType(name)
 function schema:getType(name)
@@ -90,4 +91,24 @@ function schema:getDirective(name)
   return self.directiveMap[name]
   return self.directiveMap[name]
 end
 end
 
 
+function schema:getQueryType()
+  return self.query
+end
+
+function schema:getMutationType()
+  return self.mutation
+end
+
+function schema:getTypeMap()
+  return self.typeMap
+end
+
+function schema:getPossibleTypes(abstractType)
+  if abstractType.__type == 'Union' then
+    return abstractType.types
+  end
+
+  return self:getImplementors(abstractType)
+end
+
 return schema
 return schema

+ 41 - 13
graphql/types.lua

@@ -61,6 +61,7 @@ function types.object(config)
   local instance = {
   local instance = {
     __type = 'Object',
     __type = 'Object',
     name = config.name,
     name = config.name,
+    description = config.description,
     isTypeOf = config.isTypeOf,
     isTypeOf = config.isTypeOf,
     fields = fields,
     fields = fields,
     interfaces = config.interfaces
     interfaces = config.interfaces
@@ -108,6 +109,8 @@ function initFields(kind, fields)
     result[fieldName] = {
     result[fieldName] = {
       name = fieldName,
       name = fieldName,
       kind = field.kind,
       kind = field.kind,
+      description = field.description,
+      deprecationReason = field.deprecationReason,
       arguments = field.arguments or {},
       arguments = field.arguments or {},
       resolve = kind == 'Object' and field.resolve or nil
       resolve = kind == 'Object' and field.resolve or nil
     }
     }
@@ -120,11 +123,28 @@ function types.enum(config)
   assert(type(config.name) == 'string', 'type name must be provided as a string')
   assert(type(config.name) == 'string', 'type name must be provided as a string')
   assert(type(config.values) == 'table', 'values table must be provided')
   assert(type(config.values) == 'table', 'values table must be provided')
 
 
-  local instance = {
+  local instance
+  local values = {}
+
+  for name, entry in pairs(config.values) do
+    entry = type(entry) == 'table' and entry or { value = entry }
+
+    values[name] = {
+      name = name,
+      description = entry.description,
+      deprecationReason = entry.deprecationReason,
+      value = entry.value
+    }
+  end
+
+  instance = {
     __type = 'Enum',
     __type = 'Enum',
     name = config.name,
     name = config.name,
     description = config.description,
     description = config.description,
-    values = config.values
+    values = values,
+    serialize = function(name)
+      return instance.values[name] and instance.values[name].value or name
+    end
   }
   }
 
 
   instance.nonNull = types.nonNull(instance)
   instance.nonNull = types.nonNull(instance)
@@ -181,6 +201,7 @@ end
 
 
 types.int = types.scalar({
 types.int = types.scalar({
   name = 'Int',
   name = 'Int',
+  description = "The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1. ", 
   serialize = coerceInt,
   serialize = coerceInt,
   parseValue = coerceInt,
   parseValue = coerceInt,
   parseLiteral = function(node)
   parseLiteral = function(node)
@@ -203,6 +224,7 @@ types.float = types.scalar({
 
 
 types.string = types.scalar({
 types.string = types.scalar({
   name = 'String',
   name = 'String',
+  description = "The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.",
   serialize = tostring,
   serialize = tostring,
   parseValue = tostring,
   parseValue = tostring,
   parseLiteral = function(node)
   parseLiteral = function(node)
@@ -218,6 +240,7 @@ end
 
 
 types.boolean = types.scalar({
 types.boolean = types.scalar({
   name = 'Boolean',
   name = 'Boolean',
+  description = "The `Boolean` scalar type represents `true` or `false`.",
   serialize = toboolean,
   serialize = toboolean,
   parseValue = toboolean,
   parseValue = toboolean,
   parseLiteral = function(node)
   parseLiteral = function(node)
@@ -246,9 +269,12 @@ function types.directive(config)
     name = config.name,
     name = config.name,
     description = config.description,
     description = config.description,
     arguments = config.arguments,
     arguments = config.arguments,
-    onOperation = config.onOperation or false,
-    onFragment = config.onOperation or false,
-    onField = config.onField or false
+    onQuery = config.onQuery,
+    onMutation = config.onMutation,
+    onField = config.onField,
+    onFragmentDefinition = config.onFragmentDefinition,
+    onFragmentSpread = config.onFragmentSpread,
+    onInlineFragment = config.onInlineFragment
   }
   }
 
 
   return instance
   return instance
@@ -256,22 +282,24 @@ end
 
 
 types.include = types.directive({
 types.include = types.directive({
   name = 'include',
   name = 'include',
+  description = 'Directs the executor to include this field or fragment only when the `if` argument is true.',
   arguments = {
   arguments = {
-    ['if'] = types.boolean.nonNull
+    ['if'] = { kind = types.boolean.nonNull, description = 'Included when true.'}
   },
   },
-  onOperation = false,
-  onFragment = true,
-  onField = true
+  onField = true,
+  onFragmentSpread = true,
+  onInlineFragment = true
 })
 })
 
 
 types.skip = types.directive({
 types.skip = types.directive({
   name = 'skip',
   name = 'skip',
+  description = 'Directs the executor to skip this field or fragment when the `if` argument is true.',
   arguments = {
   arguments = {
-    ['if'] = types.boolean.nonNull
+    ['if'] = { kind = types.boolean.nonNull, description = 'Skipped when true.' }
   },
   },
-  onOperation = false,
-  onFragment = true,
-  onField = true
+  onField = true,
+  onFragmentSpread = true,
+  onInlineFragment = true
 })
 })
 
 
 return types
 return types

+ 20 - 8
graphql/util.lua

@@ -13,6 +13,24 @@ function util.find(t, fn)
   end
   end
 end
 end
 
 
+function util.filter(t, fn)
+  local res = {}
+  for k,v in pairs(t) do
+    if fn(v) then
+      table.insert(res, v)
+    end
+  end
+  return res
+end
+
+function util.values(t)
+  local res = {}
+  for _, value in pairs(t) do
+    table.insert(res, value)
+  end
+  return res
+end
+
 function util.compose(f, g)
 function util.compose(f, g)
   return function(...) return f(g(...)) end
   return function(...) return f(g(...)) end
 end
 end
@@ -23,14 +41,8 @@ function util.bind1(func, x)
   end
   end
 end
 end
 
 
-function util.getParentField(context, name, count)
-  count = count == nil and 1 or count
-  local obj = context.objects[#context.objects - count]
-  if obj.__type == 'List' then
-    return obj.ofType.fields[name]
-  else
-    return obj.fields[name]
-  end
+function util.trim(s)
+  return s:gsub('^%s+', ''):gsub('%s$', ''):gsub('%s%s+', ' ')
 end
 end
 
 
 function util.coerceValue(node, schemaType, variables)
 function util.coerceValue(node, schemaType, variables)

+ 40 - 9
graphql/validate.lua

@@ -1,6 +1,22 @@
 local path = (...):gsub('%.[^%.]+$', '')
 local path = (...):gsub('%.[^%.]+$', '')
 local rules = require(path .. '.rules')
 local rules = require(path .. '.rules')
 local util = require(path .. '.util')
 local util = require(path .. '.util')
+local introspection = require(path .. '.introspection')
+local schema = require(path .. '.schema')
+
+local function getParentField(context, name, count)
+  if introspection.fieldMap[name] then return introspection.fieldMap[name] end
+
+  count = count or 1
+  local parent = context.objects[#context.objects - count]
+
+  -- Unwrap lists and non-null types
+  while parent.ofType do
+    parent = parent.ofType
+  end
+
+  return parent.fields[name]
+end
 
 
 local visitors = {
 local visitors = {
   document = {
   document = {
@@ -21,7 +37,7 @@ local visitors = {
 
 
   operation = {
   operation = {
     enter = function(node, context)
     enter = function(node, context)
-      table.insert(context.objects, context.schema.query)
+      table.insert(context.objects, context.schema[node.operation])
       context.currentOperation = node
       context.currentOperation = node
       context.variableReferences = {}
       context.variableReferences = {}
     end,
     end,
@@ -59,11 +75,15 @@ local visitors = {
 
 
   field = {
   field = {
     enter = function(node, context)
     enter = function(node, context)
-      local parentField = util.getParentField(context, node.name.value, 0)
-
-      -- false is a special value indicating that the field was not present in the type definition.
-      local field = parentField and parentField.kind or false
-      table.insert(context.objects, field)
+      local name = node.name.value
+
+      if introspection.fieldMap[name] then
+        table.insert(context.objects, introspection.fieldMap[name].kind)
+      else
+        local parentField = getParentField(context, name, 0)
+        -- false is a special value indicating that the field was not present in the type definition.
+        table.insert(context.objects, parentField and parentField.kind or false)
+      end
     end,
     end,
 
 
     exit = function(node, context)
     exit = function(node, context)
@@ -157,9 +177,15 @@ local visitors = {
                 collectTransitiveVariables(selection)
                 collectTransitiveVariables(selection)
               end
               end
             end
             end
-          elseif referencedNode.kind == 'field' and referencedNode.arguments then
-            for _, argument in ipairs(referencedNode.arguments) do
-              collectTransitiveVariables(argument)
+          elseif referencedNode.kind == 'field' then
+            if referencedNode.arguments then
+              for _, argument in ipairs(referencedNode.arguments) do
+                collectTransitiveVariables(argument)
+              end
+            end
+
+            if referencedNode.selectionSet then
+              collectTransitiveVariables(referencedNode.selectionSet)
             end
             end
           elseif referencedNode.kind == 'argument' then
           elseif referencedNode.kind == 'argument' then
             return collectTransitiveVariables(referencedNode.value)
             return collectTransitiveVariables(referencedNode.value)
@@ -171,6 +197,7 @@ local visitors = {
             return collectTransitiveVariables(referencedNode.selectionSet)
             return collectTransitiveVariables(referencedNode.selectionSet)
           elseif referencedNode.kind == 'fragmentSpread' then
           elseif referencedNode.kind == 'fragmentSpread' then
             local fragment = context.fragmentMap[referencedNode.name.value]
             local fragment = context.fragmentMap[referencedNode.name.value]
+            context.usedFragments[referencedNode.name.value] = true
             return fragment and collectTransitiveVariables(fragment.selectionSet)
             return fragment and collectTransitiveVariables(fragment.selectionSet)
           end
           end
         end
         end
@@ -179,6 +206,10 @@ local visitors = {
       end
       end
     end,
     end,
 
 
+    exit = function(node, context)
+      table.remove(context.objects)
+    end,
+
     rules = {
     rules = {
       rules.fragmentSpreadTargetDefined,
       rules.fragmentSpreadTargetDefined,
       rules.fragmentSpreadIsPossible,
       rules.fragmentSpreadIsPossible,