Browse Source

Add tests for & fix `pk_oid` API

Signed-off-by: Steffen Jaeckel <[email protected]>
Steffen Jaeckel 1 month ago
parent
commit
07d544c5d9
4 changed files with 92 additions and 26 deletions
  1. 49 26
      src/pk/asn1/oid/pk_oid_str.c
  2. 1 0
      tests/misc_test.c
  3. 41 0
      tests/pk_oid_test.c
  4. 1 0
      tests/tomcrypt_test.h

+ 49 - 26
src/pk/asn1/oid/pk_oid_str.c

@@ -6,20 +6,18 @@
 int pk_oid_str_to_num(const char *OID, unsigned long *oid, unsigned long *oidlen)
 int pk_oid_str_to_num(const char *OID, unsigned long *oid, unsigned long *oidlen)
 {
 {
    unsigned long i, j, limit, oid_j;
    unsigned long i, j, limit, oid_j;
-   size_t OID_len;
 
 
    LTC_ARGCHK(oidlen != NULL);
    LTC_ARGCHK(oidlen != NULL);
 
 
    limit = *oidlen;
    limit = *oidlen;
    *oidlen = 0; /* make sure that we return zero oidlen on error */
    *oidlen = 0; /* make sure that we return zero oidlen on error */
-   for (i = 0; i < limit; i++) oid[i] = 0;
-
+   if (oid != NULL) {
+      XMEMSET(oid, 0, sizeof(*oid) * limit);
+   }
    if (OID == NULL) return CRYPT_OK;
    if (OID == NULL) return CRYPT_OK;
+   if (OID[0] == '\0') return CRYPT_OK;
 
 
-   OID_len = XSTRLEN(OID);
-   if (OID_len == 0) return CRYPT_OK;
-
-   for (i = 0, j = 0; i < OID_len; i++) {
+   for (i = 0, j = 0; OID[i] != '\0'; i++) {
       if (OID[i] == '.') {
       if (OID[i] == '.') {
          if (++j >= limit) continue;
          if (++j >= limit) continue;
       }
       }
@@ -34,49 +32,74 @@ int pk_oid_str_to_num(const char *OID, unsigned long *oid, unsigned long *oidlen
       }
       }
    }
    }
    if (j == 0) return CRYPT_ERROR;
    if (j == 0) return CRYPT_ERROR;
-   if (j >= limit) {
-      *oidlen = j;
+   *oidlen = j + 1;
+   if (j >= limit || oid == NULL) {
       return CRYPT_BUFFER_OVERFLOW;
       return CRYPT_BUFFER_OVERFLOW;
    }
    }
-   *oidlen = j + 1;
    return CRYPT_OK;
    return CRYPT_OK;
 }
 }
 
 
+typedef struct num_to_str {
+   int err;
+   char *wr;
+   unsigned long max_len, res_len;
+} num_to_str;
+
+static LTC_INLINE void s_wr(char c, num_to_str *w)
+{
+   if (w->res_len == ULONG_MAX) {
+      w->err = CRYPT_OVERFLOW;
+      return;
+   }
+   w->res_len++;
+   if (w->res_len > w->max_len) w->wr = NULL;
+   if (w->wr) w->wr[w->max_len - w->res_len] = c;
+}
+
 int pk_oid_num_to_str(const unsigned long *oid, unsigned long oidlen, char *OID, unsigned long *outlen)
 int pk_oid_num_to_str(const unsigned long *oid, unsigned long oidlen, char *OID, unsigned long *outlen)
 {
 {
    int i;
    int i;
-   unsigned long j, k;
-   char tmp[LTC_OID_MAX_STRLEN] = { 0 };
+   num_to_str w;
+   unsigned long j;
 
 
    LTC_ARGCHK(oid != NULL);
    LTC_ARGCHK(oid != NULL);
    LTC_ARGCHK(oidlen < INT_MAX);
    LTC_ARGCHK(oidlen < INT_MAX);
    LTC_ARGCHK(outlen != NULL);
    LTC_ARGCHK(outlen != NULL);
 
 
-   for (i = oidlen - 1, k = 0; i >= 0; i--) {
+   if (OID == NULL || *outlen == 0) {
+      w.max_len = ULONG_MAX;
+      w.wr = NULL;
+   } else {
+      w.max_len = *outlen;
+      w.wr = OID;
+   }
+   w.res_len = 0;
+   w.err = CRYPT_OK;
+
+   s_wr('\0', &w);
+   for (i = oidlen; i --> 0;) {
       j = oid[i];
       j = oid[i];
       if (j == 0) {
       if (j == 0) {
-         tmp[k] = '0';
-         if (++k >= sizeof(tmp)) return CRYPT_ERROR;
-      }
-      else {
+         s_wr('0', &w);
+      } else {
          while (j > 0) {
          while (j > 0) {
-            tmp[k] = '0' + (j % 10);
-            if (++k >= sizeof(tmp)) return CRYPT_ERROR;
+            s_wr('0' + (j % 10), &w);
             j /= 10;
             j /= 10;
          }
          }
       }
       }
       if (i > 0) {
       if (i > 0) {
-        tmp[k] = '.';
-        if (++k >= sizeof(tmp)) return CRYPT_ERROR;
+         s_wr('.', &w);
       }
       }
    }
    }
-   if (*outlen < k + 1) {
-      *outlen = k + 1;
+   if (w.err != CRYPT_OK) {
+      return w.err;
+   }
+   if (*outlen < w.res_len || OID == NULL) {
+      *outlen = w.res_len;
       return CRYPT_BUFFER_OVERFLOW;
       return CRYPT_BUFFER_OVERFLOW;
    }
    }
    LTC_ARGCHK(OID != NULL);
    LTC_ARGCHK(OID != NULL);
-   for (j = 0; j < k; j++) OID[j] = tmp[k - j - 1];
-   OID[k] = '\0';
-   *outlen = k; /* the length without terminating NUL byte */
+   XMEMMOVE(OID, OID + (w.max_len - w.res_len), w.res_len);
+   *outlen = w.res_len;
    return CRYPT_OK;
    return CRYPT_OK;
 }
 }

+ 1 - 0
tests/misc_test.c

@@ -34,6 +34,7 @@ int misc_test(void)
 #ifdef LTC_SSH
 #ifdef LTC_SSH
    ssh_test();
    ssh_test();
 #endif
 #endif
+   pk_oid_test();
    no_null_termination_check_test();
    no_null_termination_check_test();
    return 0;
    return 0;
 }
 }

+ 41 - 0
tests/pk_oid_test.c

@@ -0,0 +1,41 @@
+/* LibTomCrypt, modular cryptographic library -- Tom St Denis */
+/* SPDX-License-Identifier: Unlicense */
+
+#include  <tomcrypt_test.h>
+
+int pk_oid_test(void)
+{
+   const char *oid_str = "1.2.3.4.5";
+   const unsigned long oid_ul[] = { 1, 2, 3, 4, 5 };
+   char str[16];
+   unsigned long buf[6], num = LTC_ARRAY_SIZE(oid_ul), strlen = sizeof(str), should_size = 0;
+
+   SHOULD_FAIL_WITH(pk_oid_str_to_num(oid_str, NULL, &should_size), CRYPT_BUFFER_OVERFLOW);
+   ENSURE(should_size == 5);
+
+   DO(pk_oid_str_to_num(oid_str, buf, &num));
+   ENSURE(num == 5);
+
+   should_size = 1;
+   SHOULD_FAIL_WITH(pk_oid_num_to_str(oid_ul, 5, str, &should_size), CRYPT_BUFFER_OVERFLOW);
+   ENSURE(should_size == 10);
+   should_size = 1;
+   SHOULD_FAIL_WITH(pk_oid_num_to_str(oid_ul, 5, NULL, &should_size), CRYPT_BUFFER_OVERFLOW);
+   ENSURE(should_size == 10);
+   should_size = 16;
+   SHOULD_FAIL_WITH(pk_oid_num_to_str(oid_ul, 5, NULL, &should_size), CRYPT_BUFFER_OVERFLOW);
+   ENSURE(should_size == 10);
+
+   XMEMSET(str, 'a', sizeof(str));
+   DO(pk_oid_num_to_str(oid_ul, 5, str, &strlen));
+   ENSURE(strlen == 10);
+   ENSURE(XMEMCMP(str, oid_str, strlen) == 0);
+
+   XMEMSET(str, 'a', sizeof(str));
+   strlen = 10;
+   DO(pk_oid_num_to_str(oid_ul, 5, str, &strlen));
+   ENSURE(strlen == 10);
+   ENSURE(XMEMCMP(str, oid_str, strlen) == 0);
+
+   return 0;
+}

+ 1 - 0
tests/tomcrypt_test.h

@@ -44,6 +44,7 @@ int ed25519_test(void);
 int ssh_test(void);
 int ssh_test(void);
 int bcrypt_test(void);
 int bcrypt_test(void);
 int no_null_termination_check_test(void);
 int no_null_termination_check_test(void);
+int pk_oid_test(void);
 
 
 #ifdef LTC_PKCS_1
 #ifdef LTC_PKCS_1
 struct ltc_prng_descriptor* no_prng_desc_get(void);
 struct ltc_prng_descriptor* no_prng_desc_get(void);