Browse Source

Schema cleanup;

bjorn 9 years ago
parent
commit
c54a3c2891
2 changed files with 14 additions and 15 deletions
  1. 12 13
      graphql/schema.lua
  2. 2 2
      tests/introspection.lua

+ 12 - 13
graphql/schema.lua

@@ -7,17 +7,15 @@ schema.__index = schema
 
 function schema.create(config)
   assert(type(config.query) == 'table', 'must provide query object')
-  if config.mutation then
-    assert(type(config.mutation) == 'table', 'mutation must be a table')
-  end
-  
-  local self = {}
+  assert(not config.mutation or type(config.mutation) == 'table', 'mutation must be a table if provided')
+
+  local self = setmetatable({}, schema)
+
   for k, v in pairs(config) do
     self[k] = v
   end
 
-  self.typeMap = {
-  }
+  self.typeMap = {}
 
   self.interfaceMap = {}
   self.directiveMap = {}
@@ -76,7 +74,7 @@ function schema.create(config)
     end
   end
 
-  return setmetatable(self, schema)
+  return self
 end
 
 function schema:getType(name)
@@ -109,14 +107,15 @@ end
 
 function schema:getPossibleTypes(abstractType)
   if abstractType.__type == 'Union' then
-    return abstractType.types;
+    return abstractType.types
   end
-  return self:getImplementors(abstractType);
-end
 
+  return self:getImplementors(abstractType)
+end
 
 function schema.getParentField(context, name, count)
   local parent = nil
+
   if name == '__schema' then
     parent = introspection.SchemaMetaFieldDef
   elseif name == '__type' then
@@ -124,13 +123,13 @@ function schema.getParentField(context, name, count)
   elseif name == '__typename' then
     parent = introspection.TypeNameMetaFieldDef
   else
-    count = count == nil and 1 or count
+    count = count or 1
     local obj = context.objects[#context.objects - count]
     if obj.ofType then obj = obj.ofType end
     parent = obj.fields[name]
   end
+
   return parent
 end
 
-
 return schema

+ 2 - 2
tests/introspection.lua

@@ -1250,8 +1250,8 @@ describe('introspection', function()
   local operationName = 'IntrospectionQuery'
   local response = execute(schema, parse(introspection_query), rootValue, variables, operationName)
   local expected = cjson.decode(introspection_expected_json)
-  assert:set_parameter("TableFormatLevel", 10) 
-  local compare_by_name = function(a,b) return a.name < b.name end  
+  assert:set_parameter("TableFormatLevel", 10)
+  local compare_by_name = function(a,b) return a.name < b.name end
   table.sort(response.__schema.directives, compare_by_name)
   table.sort(expected.__schema.directives, compare_by_name)
   table.sort(response.__schema.types, compare_by_name)