Browse Source

rsa_test: improve a bit

Steffen Jaeckel 11 years ago
parent
commit
f86d36c676
1 changed files with 31 additions and 25 deletions
  1. 31 25
      testprof/rsa_test.c

+ 31 - 25
testprof/rsa_test.c

@@ -100,6 +100,19 @@ static int rsa_compat_test(void)
    return 0;
 }
 
+static void _rsa_testPrint(const char* what, const unsigned char* p, const unsigned long l)
+{
+  unsigned long x;
+  fprintf(stderr, "%s contents: \n", what);
+  for (x = 0; x < l; ) {
+      fprintf(stderr, "%02x ", p[x]);
+      if (!(++x % 16)) {
+         fprintf(stderr, "\n");
+      }
+  }
+  fprintf(stderr, "\n");
+}
+
 int rsa_test(void)
 {
    unsigned char in[1024], out[1024], tmp[1024];
@@ -186,24 +199,9 @@ for (cnt = 0; cnt < len; ) {
          return 1;
       }
       if (len2 != rsa_msgsize || memcmp(tmp, in, rsa_msgsize)) {
-         unsigned long x;
          fprintf(stderr, "\nrsa_decrypt_key mismatch, len %lu (second decrypt)\n", len2);
-         fprintf(stderr, "Original contents: \n");
-         for (x = 0; x < rsa_msgsize; ) {
-             fprintf(stderr, "%02x ", in[x]);
-             if (!(++x % 16)) {
-                fprintf(stderr, "\n");
-             }
-         }
-         fprintf(stderr, "\n");
-         fprintf(stderr, "Output contents: \n");
-         for (x = 0; x < rsa_msgsize; ) {
-             fprintf(stderr, "%02x ", out[x]);
-             if (!(++x % 16)) {
-                fprintf(stderr, "\n");
-             }
-         }
-         fprintf(stderr, "\n");
+         _rsa_testPrint("Original", in, rsa_msgsize);
+         _rsa_testPrint("Output", tmp, len2);
          return 1;
       }
    }
@@ -232,6 +230,8 @@ for (cnt = 0; cnt < len; ) {
       }
       if (len2 != rsa_msgsize || memcmp(tmp, in, rsa_msgsize)) {
          fprintf(stderr, "rsa_decrypt_key mismatch len %lu", len2);
+         _rsa_testPrint("Original", in, rsa_msgsize);
+         _rsa_testPrint("Output", tmp, len2);
          return 1;
       }
    }
@@ -250,10 +250,16 @@ for (cnt = 0; cnt < len; ) {
          fprintf(stderr, "rsa_decrypt_key_ex failed, %d, %d", stat, stat2);
          return 1;
       }
-      if (len2 != rsa_msgsize || memcmp(tmp, in, rsa_msgsize)) {
+      if (len2 != rsa_msgsize) {
          fprintf(stderr, "rsa_decrypt_key_ex mismatch len %lu", len2);
          return 1;
       }
+      if (memcmp(tmp, in, rsa_msgsize)) {
+         fprintf(stderr, "rsa_decrypt_key_ex mismatch data");
+         _rsa_testPrint("Original", in, rsa_msgsize);
+         _rsa_testPrint("Output", tmp, rsa_msgsize);
+         return 1;
+      }
    }
 
    /* sign a message (unsalted, lower cholestorol and Atkins approved) now */
@@ -354,10 +360,12 @@ for (cnt = 0; cnt < len; ) {
     * (4) Forge the structure of PKCS#1-EMSA encoded data
     * (4.1) Search for start and end of the padding string
     * (4.2) Move the signature to the front of the padding string
-    * (4.3) Fill the message until the end with random data
+    * (4.3) Zero the message until the end
     * (5) Encrypt the package again
     * (6) Profit :)
-    *     Verification process should succeed, but result should not be valid
+    *     For PS lengths < 8:  the verification process should fail
+    *     For PS lengths >= 8: the verification process should succeed
+    *     For all PS lengths:  the result should not be valid
     */
 
    unsigned char* p = in;
@@ -380,7 +388,7 @@ for (cnt = 0; cnt < len; ) {
      printf("\nBefore:");
      for (cnt = 0; cnt < len3; ++cnt) {
        if (cnt%32 == 0)
-         printf("\n%3d:", cnt);
+         printf("\n%3lu:", cnt);
        printf(" %02x", p3[cnt]);
      }
 #endif
@@ -397,15 +405,13 @@ for (cnt = 0; cnt < len; ) {
      memmove(&p3[cnt+i], &p3[cnt2], len3-cnt2);
      /* (4.3) */
      for (cnt = cnt + len3-cnt2+i; cnt < len; ++cnt) {
-        do {
-            p3[cnt] = (unsigned char)rand();
-        } while (p3[cnt] == 0);
+        p3[cnt] = 0;
      }
 #if 0
      printf("\nAfter:");
      for (cnt = 0; cnt < len3; ++cnt) {
        if (cnt%32 == 0)
-         printf("\n%3d:", cnt);
+         printf("\n%3lu:", cnt);
        printf(" %02x", p3[cnt]);
      }
      printf("\n");