Bläddra i källkod

Add *_eq operation for integer vectors

rexim 3 veckor sedan
förälder
incheckning
53a7425f4e
2 ändrade filer med 97 tillägg och 0 borttagningar
  1. 55 0
      la.h
  2. 42 0
      src/lag.c

+ 55 - 0
la.h

@@ -2,6 +2,7 @@
 #define LA_H_
 
 #include <math.h>
+#include <stdbool.h>
 
 #ifndef LADEF
 #define LADEF static inline
@@ -183,6 +184,9 @@ LADEF V2u v2u_max(V2u a, V2u b);
 LADEF V2u v2u_clamp(V2u x, V2u a, V2u b);
 LADEF unsigned int v2u_sqrlen(V2u a);
 
+LADEF bool v2i_eq(V2i a, V2i b);
+LADEF bool v2u_eq(V2u a, V2u b);
+
 #define V3f_Fmt "v3f(%f, %f, %f)"
 #define V3f_Arg(v) (v).x, (v).y, (v).z
 LADEF V3f v3f(float x, float y, float z);
@@ -299,6 +303,9 @@ LADEF V3u v3u_max(V3u a, V3u b);
 LADEF V3u v3u_clamp(V3u x, V3u a, V3u b);
 LADEF unsigned int v3u_sqrlen(V3u a);
 
+LADEF bool v3i_eq(V3i a, V3i b);
+LADEF bool v3u_eq(V3u a, V3u b);
+
 #define V4f_Fmt "v4f(%f, %f, %f, %f)"
 #define V4f_Arg(v) (v).x, (v).y, (v).z, (v).w
 LADEF V4f v4f(float x, float y, float z, float w);
@@ -415,6 +422,9 @@ LADEF V4u v4u_max(V4u a, V4u b);
 LADEF V4u v4u_clamp(V4u x, V4u a, V4u b);
 LADEF unsigned int v4u_sqrlen(V4u a);
 
+LADEF bool v4i_eq(V4i a, V4i b);
+LADEF bool v4u_eq(V4u a, V4u b);
+
 #endif // LA_H_
 
 #ifdef LA_IMPLEMENTATION
@@ -1225,6 +1235,19 @@ LADEF unsigned int v2u_sqrlen(V2u a)
     return a.x*a.x + a.y*a.y;
 }
 
+LADEF bool v2i_eq(V2i a, V2i b)
+{
+    if (a.x != b.x) return false;
+    if (a.y != b.y) return false;
+    return true;
+}
+LADEF bool v2u_eq(V2u a, V2u b)
+{
+    if (a.x != b.x) return false;
+    if (a.y != b.y) return false;
+    return true;
+}
+
 LADEF V3f v3f(float x, float y, float z)
 {
     V3f v;
@@ -2075,6 +2098,21 @@ LADEF unsigned int v3u_sqrlen(V3u a)
     return a.x*a.x + a.y*a.y + a.z*a.z;
 }
 
+LADEF bool v3i_eq(V3i a, V3i b)
+{
+    if (a.x != b.x) return false;
+    if (a.y != b.y) return false;
+    if (a.z != b.z) return false;
+    return true;
+}
+LADEF bool v3u_eq(V3u a, V3u b)
+{
+    if (a.x != b.x) return false;
+    if (a.y != b.y) return false;
+    if (a.z != b.z) return false;
+    return true;
+}
+
 LADEF V4f v4f(float x, float y, float z, float w)
 {
     V4f v;
@@ -3019,4 +3057,21 @@ LADEF unsigned int v4u_sqrlen(V4u a)
     return a.x*a.x + a.y*a.y + a.z*a.z + a.w*a.w;
 }
 
+LADEF bool v4i_eq(V4i a, V4i b)
+{
+    if (a.x != b.x) return false;
+    if (a.y != b.y) return false;
+    if (a.z != b.z) return false;
+    if (a.w != b.w) return false;
+    return true;
+}
+LADEF bool v4u_eq(V4u a, V4u b)
+{
+    if (a.x != b.x) return false;
+    if (a.y != b.y) return false;
+    if (a.z != b.z) return false;
+    if (a.w != b.w) return false;
+    return true;
+}
+
 #endif // LA_IMPLEMENTATION

+ 42 - 0
src/lag.c

@@ -164,6 +164,35 @@ void gen_vector_scalar_ctor(FILE *stream, Stmt stmt, size_t n, Type_Def type_def
     }
 }
 
+void gen_vector_eq(FILE *stream, Stmt stmt, size_t n, Type type)
+{
+    Type_Def type_def = type_defs[type];
+    const char *vector_type = make_vector_type(n, type_def);
+    const char *vector_prefix = make_vector_prefix(n, type_def);
+    const char *name = temp_sprintf("%s_eq", vector_prefix);
+
+    static_assert(OP_ARITY >= 2, "This code assumes that operation's arity is at least 2");
+    gen_func_sig(stream, "bool", name, vector_type, op_arg_names, 2);
+    switch (stmt) {
+    case STMT_DECL: {
+        fprintf(stream, ";\n");
+    } break;
+    case STMT_IMPL: {
+        fprintf(stream, "\n");
+        fprintf(stream, "{\n");
+        assert(n <= VECTOR_MAX_SIZE);
+        for (size_t i = 0; i < n; ++i) {
+            fprintf(stream, "    if (%s.%s != %s.%s) return false;\n",
+                    op_arg_names[0], vector_comps[i],
+                    op_arg_names[1], vector_comps[i]);
+        }
+        fprintf(stream, "    return true;\n");
+        fprintf(stream, "}\n");
+    } break;
+    default: UNREACHABLE(temp_sprintf("invalid stmt: %d", stmt));
+    }
+}
+
 void gen_vector_op(FILE *stream, Stmt stmt, size_t n, Type type, Op_Type op_type)
 {
     Type_Def type_def = type_defs[type];
@@ -602,6 +631,7 @@ int main()
         fprintf(stream, "#define LA_H_\n");
         fprintf(stream, "\n");
         fprintf(stream, "#include <math.h>\n");
+        fprintf(stream, "#include <stdbool.h>\n");
         fprintf(stream, "\n");
         fprintf(stream, "#ifndef LADEF\n");
         fprintf(stream, "#define LADEF static inline\n");
@@ -640,6 +670,7 @@ int main()
                 for (Op_Type op = 0; op < COUNT_OPS; ++op) {
                     gen_vector_op(stream, STMT_DECL, n, type, op);
                 }
+
                 for (Fun_Type fun = 0; fun < COUNT_FUNS; ++fun) {
                     if (fun_defs[fun].name_for_type[type]) {
                         gen_vector_fun(stream, STMT_DECL, n, type, fun);
@@ -651,6 +682,11 @@ int main()
                 }
                 fprintf(stream, "\n");
             }
+
+            static_assert(COUNT_TYPES == 4, "Amount of types has changed");
+            gen_vector_eq(stream, STMT_DECL, n, TYPE_INT);
+            gen_vector_eq(stream, STMT_DECL, n, TYPE_UNSIGNED_INT);
+            fprintf(stream, "\n");
         }
 
         fprintf(stream, "#endif // LA_H_\n");
@@ -696,6 +732,7 @@ int main()
                     gen_vector_op(stream, STMT_IMPL, n, type, op);
                     fputc('\n', stream);
                 }
+
                 for (Fun_Type fun = 0; fun < COUNT_FUNS; ++fun) {
                     if (fun_defs[fun].name_for_type[type]) {
                         gen_vector_fun(stream, STMT_IMPL, n, type, fun);
@@ -709,6 +746,11 @@ int main()
                     fputc('\n', stream);
                 }
             }
+
+            static_assert(COUNT_TYPES == 4, "Amount of types has changed");
+            gen_vector_eq(stream, STMT_IMPL, n, TYPE_INT);
+            gen_vector_eq(stream, STMT_IMPL, n, TYPE_UNSIGNED_INT);
+            fputc('\n', stream);
         }
         fprintf(stream, "#endif // LA_IMPLEMENTATION\n");
     }