Browse Source

Allow seek of IDecryptStream to begin (for looping encrypted audio)

rdb 6 years ago
parent
commit
a7c743fd5e

+ 20 - 0
dtool/src/prc/encryptStream.cxx

@@ -12,3 +12,23 @@
  */
  */
 
 
 #include "encryptStream.h"
 #include "encryptStream.h"
+
+/**
+ * Must be called immediately after open_read().  Decrypts the given number of
+ * bytes and checks that they match.  The amount of header bytes are added to
+ * an offset so that skipping to 0 will skip past the header.
+ *
+ * Returns true if the read magic matches the given magic, false on error.
+ */
+bool IDecryptStream::
+read_magic(const char *magic, size_t size) {
+  char this_magic[size];
+  read(this_magic, size);
+
+  if (!fail() && gcount() == size && memcmp(this_magic, magic, size) == 0) {
+    _buf.set_magic_length(size);
+    return true;
+  } else {
+    return false;
+  }
+}

+ 3 - 0
dtool/src/prc/encryptStream.h

@@ -53,6 +53,9 @@ PUBLISHED:
   MAKE_PROPERTY(key_length, get_key_length);
   MAKE_PROPERTY(key_length, get_key_length);
   MAKE_PROPERTY(iteration_count, get_iteration_count);
   MAKE_PROPERTY(iteration_count, get_iteration_count);
 
 
+public:
+  bool read_magic(const char *magic, size_t size);
+
 private:
 private:
   EncryptStreamBuf _buf;
   EncryptStreamBuf _buf;
 };
 };

+ 18 - 0
dtool/src/prc/encryptStreamBuf.I

@@ -81,3 +81,21 @@ INLINE int EncryptStreamBuf::
 get_iteration_count() const {
 get_iteration_count() const {
   return _iteration_count;
   return _iteration_count;
 }
 }
+
+/**
+ * Sets the amount of the encrypted data at the beginning that are skipped
+ * when seeking back to zero.
+ */
+INLINE void EncryptStreamBuf::
+set_magic_length(size_t length) {
+  _magic_length = length;
+}
+
+/**
+ * Sets the amount of the encrypted data at the beginning that are skipped
+ * when seeking back to zero.
+ */
+INLINE size_t EncryptStreamBuf::
+get_magic_length() const {
+  return _magic_length;
+}

+ 54 - 3
dtool/src/prc/encryptStreamBuf.cxx

@@ -177,6 +177,7 @@ open_read(std::istream *source, bool owns_source, const std::string &password) {
 
 
   _read_overflow_buffer = new unsigned char[_read_block_size];
   _read_overflow_buffer = new unsigned char[_read_block_size];
   _in_read_overflow_buffer = 0;
   _in_read_overflow_buffer = 0;
+  _finished = false;
   thread_consider_yield();
   thread_consider_yield();
 }
 }
 
 
@@ -322,6 +323,57 @@ close_write() {
   }
   }
 }
 }
 
 
+/**
+ * Implements seeking within the stream.  EncryptStreamBuf only allows seeking
+ * back to the beginning of the stream.
+ */
+std::streampos EncryptStreamBuf::
+seekoff(std::streamoff off, ios_seekdir dir, ios_openmode which) {
+  if (which != std::ios::in) {
+    // We can only do this with the input stream.
+    return -1;
+  }
+
+  if (off != 0 || dir != std::ios::beg) {
+    // We only know how to reposition to the beginning.
+    return -1;
+  }
+
+  size_t n = egptr() - gptr();
+  gbump(n);
+
+  if (_source->rdbuf()->pubseekpos(0, std::ios::in) == (std::streampos)0) {
+    int result = EVP_DecryptInit(_read_ctx, nullptr, nullptr, nullptr);
+    nassertr_always(result > 0, -1);
+
+    _source->clear();
+    _in_read_overflow_buffer = 0;
+    _finished = false;
+
+    // Skip past the header.
+    int iv_length = EVP_CIPHER_CTX_iv_length(_read_ctx);
+    _source->ignore(6 + iv_length);
+
+    // Ignore the magic bytes.
+    size_t magic_length = get_magic_length();
+    char *buffer = (char *)alloca(magic_length);
+    if (read_chars(buffer, magic_length) == magic_length) {
+      return 0;
+    }
+  }
+
+  return -1;
+}
+
+/**
+ * Implements seeking within the stream.  EncryptStreamBuf only allows seeking
+ * back to the beginning of the stream.
+ */
+std::streampos EncryptStreamBuf::
+seekpos(std::streampos pos, ios_openmode which) {
+  return seekoff(pos, std::ios::beg, which);
+}
+
 /**
 /**
  * Called by the system ostream implementation when its internal buffer is
  * Called by the system ostream implementation when its internal buffer is
  * filled, plus one character.
  * filled, plus one character.
@@ -423,7 +475,7 @@ read_chars(char *start, size_t length) {
 
 
   do {
   do {
     // Get more bytes from the stream.
     // Get more bytes from the stream.
-    if (_read_ctx == nullptr) {
+    if (_read_ctx == nullptr || _finished) {
       return 0;
       return 0;
     }
     }
 
 
@@ -439,8 +491,7 @@ read_chars(char *start, size_t length) {
     } else {
     } else {
       result =
       result =
         EVP_DecryptFinal(_read_ctx, read_buffer, &bytes_read);
         EVP_DecryptFinal(_read_ctx, read_buffer, &bytes_read);
-      EVP_CIPHER_CTX_free(_read_ctx);
-      _read_ctx = nullptr;
+      _finished = true;
     }
     }
 
 
     if (result <= 0) {
     if (result <= 0) {

+ 9 - 0
dtool/src/prc/encryptStreamBuf.h

@@ -44,6 +44,12 @@ public:
   INLINE void set_iteration_count(int iteration_count);
   INLINE void set_iteration_count(int iteration_count);
   INLINE int get_iteration_count() const;
   INLINE int get_iteration_count() const;
 
 
+  INLINE void set_magic_length(size_t length);
+  INLINE size_t get_magic_length() const;
+
+  virtual std::streampos seekoff(std::streamoff off, ios_seekdir dir, ios_openmode which);
+  virtual std::streampos seekpos(std::streampos pos, ios_openmode which);
+
 protected:
 protected:
   virtual int overflow(int c);
   virtual int overflow(int c);
   virtual int sync();
   virtual int sync();
@@ -71,6 +77,9 @@ private:
 
 
   EVP_CIPHER_CTX *_write_ctx;
   EVP_CIPHER_CTX *_write_ctx;
   size_t _write_block_size;
   size_t _write_block_size;
+
+  size_t _magic_length = 0;
+  bool _finished = false;
 };
 };
 
 
 #include "encryptStreamBuf.I"
 #include "encryptStreamBuf.I"

+ 1 - 4
panda/src/express/multifile.cxx

@@ -2068,10 +2068,7 @@ open_read_subfile(Subfile *subfile) {
     stream = wrapper;
     stream = wrapper;
 
 
     // Validate the password by confirming that the encryption header matches.
     // Validate the password by confirming that the encryption header matches.
-    char this_header[_encrypt_header_size];
-    stream->read(this_header, _encrypt_header_size);
-    if (stream->fail() || stream->gcount() != (unsigned)_encrypt_header_size ||
-        memcmp(this_header, _encrypt_header, _encrypt_header_size) != 0) {
+    if (!wrapper->read_magic(_encrypt_header, _encrypt_header_size)) {
       express_cat.error()
       express_cat.error()
         << "Unable to decrypt subfile " << subfile->_name << ".\n";
         << "Unable to decrypt subfile " << subfile->_name << ".\n";
       delete stream;
       delete stream;

+ 19 - 0
tests/prc/test_encrypt_stream.py

@@ -0,0 +1,19 @@
+from panda3d import core
+
+import pytest
+
+
[email protected](not hasattr(core, 'IDecryptStream'), reason="Requires OpenSSL")
+def test_decrypt_stream():
+    encrypted = b'[\x00\x10\x00d\x00\x07K\x08\x03\xabS\x13L\xab\x93\x1b\x15\xe4\xeel\x80u o\xd0\x80aY_]\x10\x8a\xb5\xff\x9d1\xc9\xd3\xac\x95\x04\xd8\xdf\x10\xa1'
+    decrypted = b'abcdefghijklmnopqrstuvwxyz'
+
+    ss = core.StringStream(encrypted)
+    ds = core.IDecryptStream(ss, False, '0123456789')
+
+    assert ds.read(len(decrypted)) == decrypted
+    assert ds.readall() == b''
+
+    # Allow seeking back to the beginning
+    ds.seekg(0)
+    assert ds.readall() == decrypted