Browse Source

Add util;

bjorn 9 years ago
parent
commit
748f853fe9
2 changed files with 86 additions and 62 deletions
  1. 10 62
      rules.lua
  2. 76 0
      util.lua

+ 10 - 62
rules.lua

@@ -1,51 +1,7 @@
 local types = require 'types'
-local rules = {}
-
-local function checkValueMatchesType(valueNode, schemaType)
-  if schemaType.__type == 'NonNull' then
-    return checkValueMatchesType(schemaType.ofType, valueNode)
-  end
-
-  if schemaType.__type == 'List' then
-    if valueNode.kind ~= 'list' then
-      error('Expected a list')
-    end
-
-    for i = 1, #valueNode.values do
-      checkValueMatchesType(schemaType.ofType, valueNode.values[i])
-    end
-  end
-
-  if schemaType.__type == 'InputObject' then
-    if valueNode.kind ~= 'inputObject' then
-      error('Expected an input object')
-    end
-
-    for _, field in ipairs(valueNode.values) do
-      if not schemaType.fields[field.name] then
-        error('Unknown input object field "' .. field.name .. '"')
-      end
-
-      checkValueMatchesType(schemaType.fields[field.name].kind, field.value)
-    end
-  end
-
-  if schemaType.__type == 'Enum' then
-    if valueNode.kind ~= 'enum' then
-      error('Expected enum value, got ' .. valueNode.kind)
-    end
+local util = require 'util'
 
-    if not schemaType.values[valueNode.value] then
-      error('Invalid enum value "' .. valueNode.value .. '"')
-    end
-  end
-
-  if schemaType.__type == 'Scalar' then
-    if schemaType.parseLiteral(valueNode) == nil then
-      error('Could not coerce "' .. valueNode.value .. '" to "' .. schemaType.name .. '"')
-    end
-  end
-end
+local rules = {}
 
 function rules.uniqueOperationNames(node, context)
   local name = node.name and node.name.value
@@ -225,7 +181,7 @@ function rules.argumentsOfCorrectType(node, context)
     for _, argument in pairs(node.arguments) do
       local name = argument.name.value
       local argumentType = parentField.arguments[name]
-      checkValueMatchesType(argument.value, argumentType)
+      util.coerceValue(argument.value, argumentType)
     end
   end
 end
@@ -235,13 +191,9 @@ function rules.requiredArgumentsPresent(node, context)
   local parentField = context.objects[#context.objects - 1].fields[node.name.value]
   for name, argument in pairs(parentField.arguments) do
     if argument.__type == 'NonNull' then
-      local present = false
-      for i = 1, #arguments do
-        if arguments[i].name.value == name then
-          present = true
-          break
-        end
-      end
+      local present = util.find(arguments, function(argument)
+        return argument.name.value == name
+      end)
 
       if not present then
         error('Required argument "' .. name .. '" was not supplied.')
@@ -350,14 +302,10 @@ function rules.fragmentSpreadIsPossible(node, context)
 
   local parentTypes = getTypes(parentType)
   local fragmentTypes = getTypes(fragmentType)
-  local valid = false
 
-  for _, kind in pairs(parentTypes) do
-    if fragmentTypes[kind] then
-      valid = true
-      break
-    end
-  end
+  local valid = util.find(parentTypes, function(kind)
+    return fragmentTypes[kind]
+  end)
 
   if not valid then
     error('Fragment type condition is not possible for given type')
@@ -422,7 +370,7 @@ function rules.variableDefaultValuesHaveCorrectType(node, context)
       if definition.type.kind == 'nonNullType' and definition.defaultValue then
         error('Non-null variables can not have default values')
       elseif definition.defaultValue then
-        checkValueMatchesType(definition.defaultValue, context.schema:getType(definition.type.name.value))
+        util.coerceValue(definition.defaultValue, context.schema:getType(definition.type.name.value))
       end
     end
   end

+ 76 - 0
util.lua

@@ -0,0 +1,76 @@
+local util = {}
+
+function util.map(t, fn)
+  local res = {}
+  for k, v in pairs(t) do res[k] = fn(v, k) end
+  return res
+end
+
+function util.find(t, fn)
+  local res = {}
+  for k, v in pairs(t) do
+    if fn(v, k) then return v end
+  end
+end
+
+function util.coerceValue(node, schemaType, variables)
+  variables = variables or {}
+
+  if schemaType.__type == 'NonNull' then
+    return util.coerceValue(node, schemaType.ofType)
+  end
+
+  if not node then
+    return nil
+  end
+
+  if node.kind == 'variable' then
+    return variables[node.name.value]
+  end
+
+  if schemaType.__type == 'List' then
+    if node.kind ~= 'list' then
+      error('Expected a list')
+    end
+
+    return util.map(node.values, function(value)
+      return util.coerceValue(node.values[i], schemaType.ofType)
+    end)
+  end
+
+  if schemaType.__type == 'InputObject' then
+    if node.kind ~= 'inputObject' then
+      error('Expected an input object')
+    end
+
+    return util.map(node.values, function(field)
+      if not schemaType.fields[field.name] then
+        error('Unknown input object field "' .. field.name .. '"')
+      end
+
+      return util.coerceValue(schemaType.fields[field.name].kind, field.value)
+    end)
+  end
+
+  if schemaType.__type == 'Enum' then
+    if node.kind ~= 'enum' then
+      error('Expected enum value, got ' .. node.kind)
+    end
+
+    if not schemaType.values[node.value] then
+      error('Invalid enum value "' .. node.value .. '"')
+    end
+
+    return node.value
+  end
+
+  if schemaType.__type == 'Scalar' then
+    if schemaType.parseLiteral(node) == nil then
+      error('Could not coerce "' .. node.value .. '" to "' .. schemaType.name .. '"')
+    end
+
+    return schemaType.parseLiteral(node)
+  end
+end
+
+return util