Browse Source

Improve MP3 detection to reduce false positive.

Miku AuahDark 3 years ago
parent
commit
c9ab08a2d3
2 changed files with 110 additions and 11 deletions
  1. 106 11
      src/modules/sound/lullaby/MP3Decoder.cpp
  2. 4 0
      src/modules/sound/lullaby/MP3Decoder.h

+ 106 - 11
src/modules/sound/lullaby/MP3Decoder.cpp

@@ -32,23 +32,119 @@ namespace lullaby
 
 static size_t onRead(void *pUserData, void *pBufferOut, size_t bytesToRead)
 {
-	auto stream = (Stream *) pUserData;
-	int64 read = stream->read(pBufferOut, bytesToRead);
+	auto decoder = (MP3Decoder *) pUserData;
+	int64 read = decoder->stream->read(pBufferOut, bytesToRead);
 	return std::max<int64>(0, read);
 }
 
 static drmp3_bool32 onSeek(void *pUserData, int offset, drmp3_seek_origin origin)
 {
-	auto stream = (Stream *) pUserData;
-	auto seekorigin = origin == drmp3_seek_origin_current ? Stream::SEEKORIGIN_CURRENT : Stream::SEEKORIGIN_BEGIN;
-	return stream->seek(offset, seekorigin) ? DRMP3_TRUE : DRMP3_FALSE;
+	auto decoder = (MP3Decoder *) pUserData;
+	int64 pos = decoder->offset;
+
+	// Due to possible offsets, we have to calculate the position ourself.
+	switch (origin)
+	{
+	case drmp3_seek_origin_start:
+		pos += offset;
+		break;
+	case drmp3_seek_origin_current:
+		pos = decoder->stream->tell() + offset;
+		break;
+	default:
+		return false;
+	}
+
+	if (pos < decoder->offset)
+		return false;
+
+	return decoder->stream->seek(pos, Stream::SEEKORIGIN_BEGIN) ? DRMP3_TRUE : DRMP3_FALSE;
+}
+
+// Copied from dr_mp3 function drmp3_hdr_valid()
+static bool isMP3HeaderValid(const uint8 *h)
+{
+	return
+		// Sync bits
+		h[0] == 0xff &&
+		((h[1] & 0xF0) == 0xf0 || (h[1] & 0xFE) == 0xe2) &&
+		// Check layer
+		(DRMP3_HDR_GET_LAYER(h) != 0) &&
+		// Check bitrate
+		(DRMP3_HDR_GET_BITRATE(h) != 15) &&
+		// Check sample rate
+		(DRMP3_HDR_GET_SAMPLE_RATE(h) != 3);
+}
+
+static int64 findFirstValidHeader(Stream* stream)
+{
+	constexpr size_t LOOKUP_SIZE = 16384;
+
+	std::vector<uint8> data(LOOKUP_SIZE);
+	uint8 header[10];
+	uint8 *dataPtr = data.data();
+	int64 buffer = 0;
+	int64 offset = 0;
+
+	if (stream->read(header, 10) < 10)
+		return -1;
+
+	if (memcmp(header, "TAG", 3) == 0)
+	{
+		// ID3v1 tag is always 128 bytes long
+		if (!stream->seek(128, Stream::SEEKORIGIN_BEGIN))
+			return -1;
+
+		buffer = stream->read(dataPtr, LOOKUP_SIZE);
+		offset = 128;
+	}
+	else if (memcmp(header, "ID3", 3) == 0)
+	{
+		// ID3v2 tag header is 10 bytes long, but we're
+		// only interested on how much we should skip.
+		int64 off =
+			header[9] |
+			((int64) header[8] << 7) |
+			((int64) header[7] << 14) |
+			((int64) header[6] << 21);
+
+		if (!stream->seek(off, Stream::SEEKORIGIN_CURRENT))
+			return -1;
+
+		buffer = stream->read(dataPtr, LOOKUP_SIZE);
+		offset = off + 10;
+	}
+	else
+	{
+		// Copy the rest to data buffer
+		memcpy(dataPtr, header, 10);
+		buffer = 10 + stream->read(dataPtr + 10, LOOKUP_SIZE - 10);
+	}
+
+	// Look for mp3 data
+	for (int i = 0; i < buffer - 4; i++, offset++)
+	{
+		if (isMP3HeaderValid(dataPtr++))
+		{
+			stream->seek(offset, Stream::SEEKORIGIN_BEGIN);
+			return offset;
+		}
+	}
+
+	// No valid MP3 frame found in first 16KB data
+	return -1;
 }
 
 MP3Decoder::MP3Decoder(Stream *stream, int bufferSize)
-	: Decoder(stream, bufferSize)
+: Decoder(stream, bufferSize)
 {
+	// Check for possible ID3 tag and skip it if necessary.
+	offset = findFirstValidHeader(stream);
+	if (offset == -1)
+		throw love::Exception("Could not find first valid mp3 header.");
+
 	// initialize mp3 handle
-	if (!drmp3_init(&mp3, onRead, onSeek, stream, nullptr, nullptr))
+	if (!drmp3_init(&mp3, onRead, onSeek, this, nullptr, nullptr))
 		throw love::Exception("Could not read mp3 data.");
 
 	sampleRate = mp3.sampleRate;
@@ -63,14 +159,13 @@ MP3Decoder::MP3Decoder(Stream *stream, int bufferSize)
 	duration = ((double) pcmCount) / ((double) mp3.sampleRate);
 
 	// create seek table
-	uint32_t mp3FrameInt = mp3FrameCount;
-	seekTable.resize(mp3FrameCount, {0ULL, 0ULL, 0, 0});
+	drmp3_uint32 mp3FrameInt = (drmp3_uint32) mp3FrameCount;
+	seekTable.resize((size_t) mp3FrameCount, {0ULL, 0ULL, 0, 0});
 	if (!drmp3_calculate_seek_points(&mp3, &mp3FrameInt, seekTable.data()))
 	{
 		drmp3_uninit(&mp3);
 		throw love::Exception("Could not calculate mp3 seek table");
 	}
-	mp3FrameInt = mp3FrameInt > mp3FrameCount ? mp3FrameCount : mp3FrameInt;
 
 	// bind seek table
 	if (!drmp3_bind_seek_table(&mp3, mp3FrameInt, seekTable.data()))
@@ -105,7 +200,7 @@ int MP3Decoder::decode()
 
 bool MP3Decoder::seek(double s)
 {
-	drmp3_uint64 targetSample = s * mp3.sampleRate;
+	drmp3_uint64 targetSample = (drmp3_uint64) (s * mp3.sampleRate);
 	drmp3_bool32 success = drmp3_seek_to_pcm_frame(&mp3, targetSample);
 
 	if (success)

+ 4 - 0
src/modules/sound/lullaby/MP3Decoder.h

@@ -54,11 +54,15 @@ public:
 	double getDuration() override;
 
 private:
+	friend size_t onRead(void *pUserData, void *pBufferOut, size_t bytesToRead);
+	friend drmp3_bool32 onSeek(void *pUserData, int offset, drmp3_seek_origin origin);
 
 	// MP3 handle
 	drmp3 mp3;
 	// Used for fast seeking
 	std::vector<drmp3_seek_point> seekTable;
+	// Position of first MP3 frame found
+	int64 offset;
 
 	double duration;
 }; // MP3Decoder