Browse Source

Merge pull request #757 from paullouisageneau/utils-random

Refactor random number generation
Paul-Louis Ageneau 2 years ago
parent
commit
1ec2a2f519

+ 5 - 3
examples/client-benchmark/main.cpp

@@ -479,11 +479,13 @@ shared_ptr<rtc::PeerConnection> createPeerConnection(const rtc::Configuration &c
 
 
 // Helper function to generate a random ID
 // Helper function to generate a random ID
 std::string randomId(size_t length) {
 std::string randomId(size_t length) {
+	using std::chrono::system_clock;
+	static thread_local std::mt19937 rng(
+	    static_cast<unsigned int>(system_clock::now().time_since_epoch().count()));
 	static const std::string characters(
 	static const std::string characters(
 	    "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
 	    "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
 	std::string id(length, '0');
 	std::string id(length, '0');
-	std::default_random_engine rng(std::random_device{}());
-	std::uniform_int_distribution<int> dist(0, int(characters.size() - 1));
-	std::generate(id.begin(), id.end(), [&]() { return characters.at(dist(rng)); });
+	std::uniform_int_distribution<int> uniform(0, int(characters.size() - 1));
+	std::generate(id.begin(), id.end(), [&]() { return characters.at(uniform(rng)); });
 	return id;
 	return id;
 }
 }

+ 6 - 3
examples/client/main.cpp

@@ -28,6 +28,7 @@
 #include <nlohmann/json.hpp>
 #include <nlohmann/json.hpp>
 
 
 #include <algorithm>
 #include <algorithm>
+#include <chrono>
 #include <future>
 #include <future>
 #include <iostream>
 #include <iostream>
 #include <memory>
 #include <memory>
@@ -263,11 +264,13 @@ shared_ptr<rtc::PeerConnection> createPeerConnection(const rtc::Configuration &c
 
 
 // Helper function to generate a random ID
 // Helper function to generate a random ID
 std::string randomId(size_t length) {
 std::string randomId(size_t length) {
+	using std::chrono::system_clock;
+	static thread_local std::mt19937 rng(
+	    static_cast<unsigned int>(system_clock::now().time_since_epoch().count()));
 	static const std::string characters(
 	static const std::string characters(
 	    "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
 	    "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
 	std::string id(length, '0');
 	std::string id(length, '0');
-	std::default_random_engine rng(std::random_device{}());
-	std::uniform_int_distribution<int> dist(0, int(characters.size() - 1));
-	std::generate(id.begin(), id.end(), [&]() { return characters.at(dist(rng)); });
+	std::uniform_int_distribution<int> uniform(0, int(characters.size() - 1));
+	std::generate(id.begin(), id.end(), [&]() { return characters.at(uniform(rng)); });
 	return id;
 	return id;
 }
 }

+ 2 - 4
src/description.cpp

@@ -162,10 +162,8 @@ Description::Description(const string &sdp, Type type, Role role)
 		mUsername = "rtc";
 		mUsername = "rtc";
 
 
 	if (mSessionId.empty()) {
 	if (mSessionId.empty()) {
-		auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
-		std::default_random_engine generator(seed);
-		std::uniform_int_distribution<uint32_t> uniform;
-		mSessionId = std::to_string(uniform(generator));
+		auto uniform = std::bind(std::uniform_int_distribution<uint32_t>(), utils::random_engine());
+		mSessionId = std::to_string(uniform());
 	}
 	}
 }
 }
 
 

+ 3 - 4
src/impl/icetransport.cpp

@@ -20,6 +20,7 @@
 #include "configuration.hpp"
 #include "configuration.hpp"
 #include "internals.hpp"
 #include "internals.hpp"
 #include "transport.hpp"
 #include "transport.hpp"
+#include "utils.hpp"
 
 
 #include <iostream>
 #include <iostream>
 #include <random>
 #include <random>
@@ -103,8 +104,7 @@ IceTransport::IceTransport(const Configuration &config, candidate_callback candi
 
 
 	// Randomize servers order
 	// Randomize servers order
 	std::vector<IceServer> servers = config.iceServers;
 	std::vector<IceServer> servers = config.iceServers;
-	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
-	std::shuffle(servers.begin(), servers.end(), std::default_random_engine(seed));
+	std::shuffle(servers.begin(), servers.end(), utils::random_engine());
 
 
 	// Pick a STUN server
 	// Pick a STUN server
 	for (auto &server : servers) {
 	for (auto &server : servers) {
@@ -451,8 +451,7 @@ IceTransport::IceTransport(const Configuration &config, candidate_callback candi
 
 
 	// Randomize order
 	// Randomize order
 	std::vector<IceServer> servers = config.iceServers;
 	std::vector<IceServer> servers = config.iceServers;
-	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
-	std::shuffle(servers.begin(), servers.end(), std::default_random_engine(seed));
+	std::shuffle(servers.begin(), servers.end(), utils::random_engine());
 
 
 	// Add one STUN server
 	// Add one STUN server
 	bool success = false;
 	bool success = false;

+ 20 - 0
src/impl/utils.cpp

@@ -21,6 +21,7 @@
 #include "impl/internals.hpp"
 #include "impl/internals.hpp"
 
 
 #include <cctype>
 #include <cctype>
+#include <chrono>
 #include <functional>
 #include <functional>
 #include <iterator>
 #include <iterator>
 #include <sstream>
 #include <sstream>
@@ -110,4 +111,23 @@ string base64_encode(const binary &data) {
 	return out;
 	return out;
 }
 }
 
 
+std::seed_seq random_seed() {
+	std::vector<unsigned int> seed;
+
+	// Seed with random device
+	try {
+		// On some systems an exception might be thrown if the random_device can't be initialized
+		std::random_device device;
+		seed.push_back(device());
+	} catch (const std::exception &) {
+		// Ignore
+	}
+
+	// Seed with current time
+	using std::chrono::system_clock;
+	seed.push_back(static_cast<unsigned int>(system_clock::now().time_since_epoch().count()));
+
+	return std::seed_seq(seed.begin(), seed.end());
+}
+
 } // namespace rtc::impl::utils
 } // namespace rtc::impl::utils

+ 34 - 1
src/impl/utils.hpp

@@ -21,6 +21,9 @@
 
 
 #include "common.hpp"
 #include "common.hpp"
 
 
+#include <climits>
+#include <limits>
+#include <random>
 #include <vector>
 #include <vector>
 
 
 namespace rtc::impl::utils {
 namespace rtc::impl::utils {
@@ -36,6 +39,36 @@ string url_decode(const string &str);
 // See https://www.rfc-editor.org/rfc/rfc4648.html#section-4
 // See https://www.rfc-editor.org/rfc/rfc4648.html#section-4
 string base64_encode(const binary &data);
 string base64_encode(const binary &data);
 
 
-} // namespace rtc::impl
+// Return a random seed sequence
+std::seed_seq random_seed();
+
+template <typename Generator, typename Result = typename Generator::result_type>
+struct random_engine_wrapper {
+	Generator &engine;
+	using result_type = Result;
+	static constexpr result_type min() { return static_cast<Result>(Generator::min()); }
+	static constexpr result_type max() { return static_cast<Result>(Generator::max()); }
+	inline result_type operator()() { return static_cast<Result>(engine()); }
+	inline void discard(unsigned long long z) { engine.discard(z); }
+};
+
+// Return a wrapped thread-local seeded random number generator
+template <typename Generator = std::mt19937, typename Result = typename Generator::result_type>
+auto random_engine() {
+	static thread_local std::seed_seq seed = random_seed();
+	static thread_local Generator engine{seed};
+	return random_engine_wrapper<Generator, Result>{engine};
+}
+
+// Return a wrapped thread-local seeded random bytes generator
+template <typename Generator = std::mt19937> auto random_bytes_engine() {
+	using char_independent_bits_engine =
+	    std::independent_bits_engine<Generator, CHAR_BIT, unsigned short>;
+	static_assert(char_independent_bits_engine::min() == std::numeric_limits<uint8_t>::min());
+	static_assert(char_independent_bits_engine::max() == std::numeric_limits<uint8_t>::max());
+	return random_engine<char_independent_bits_engine, uint8_t>();
+}
+
+} // namespace rtc::impl::utils
 
 
 #endif
 #endif

+ 1 - 5
src/impl/wshandshake.cpp

@@ -36,8 +36,6 @@ namespace rtc::impl {
 
 
 using std::to_string;
 using std::to_string;
 using std::chrono::system_clock;
 using std::chrono::system_clock;
-using random_bytes_engine =
-    std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
 
 
 WsHandshake::WsHandshake() {}
 WsHandshake::WsHandshake() {}
 
 
@@ -240,11 +238,9 @@ string WsHandshake::generateKey() {
 	// RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key.  The value
 	// RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key.  The value
 	// of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has
 	// of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has
 	// been base64-encoded. [...] The nonce MUST be selected randomly for each connection.
 	// been base64-encoded. [...] The nonce MUST be selected randomly for each connection.
-	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
-	random_bytes_engine generator(seed);
 	binary key(16);
 	binary key(16);
 	auto k = reinterpret_cast<uint8_t *>(key.data());
 	auto k = reinterpret_cast<uint8_t *>(key.data());
-	std::generate(k, k + key.size(), [&]() { return uint8_t(generator()); });
+	std::generate(k, k + key.size(), utils::random_bytes_engine());
 	return utils::base64_encode(key);
 	return utils::base64_encode(key);
 }
 }
 
 

+ 2 - 6
src/impl/wstransport.cpp

@@ -20,6 +20,7 @@
 #include "tcptransport.hpp"
 #include "tcptransport.hpp"
 #include "threadpool.hpp"
 #include "threadpool.hpp"
 #include "tlstransport.hpp"
 #include "tlstransport.hpp"
+#include "utils.hpp"
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
@@ -50,8 +51,6 @@ namespace rtc::impl {
 using std::to_integer;
 using std::to_integer;
 using std::to_string;
 using std::to_string;
 using std::chrono::system_clock;
 using std::chrono::system_clock;
-using random_bytes_engine =
-    std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
 
 
 WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
 WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
                          shared_ptr<WsHandshake> handshake, int maxOutstandingPings,
                          shared_ptr<WsHandshake> handshake, int maxOutstandingPings,
@@ -366,13 +365,10 @@ bool WsTransport::sendFrame(const Frame &frame) {
 	}
 	}
 
 
 	if (frame.mask) {
 	if (frame.mask) {
-		auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
-		random_bytes_engine generator(seed);
-
 		byte *maskingKey = reinterpret_cast<byte *>(cur);
 		byte *maskingKey = reinterpret_cast<byte *>(cur);
 
 
 		auto u = reinterpret_cast<uint8_t *>(maskingKey);
 		auto u = reinterpret_cast<uint8_t *>(maskingKey);
-		std::generate(u, u + 4, [&]() { return uint8_t(generator()); });
+		std::generate(u, u + 4, utils::random_bytes_engine());
 		cur += 4;
 		cur += 4;
 
 
 		for (size_t i = 0; i < frame.length; ++i)
 		for (size_t i = 0; i < frame.length; ++i)

+ 7 - 4
src/rtppacketizationconfig.cpp

@@ -20,12 +20,16 @@
 
 
 #include "rtppacketizationconfig.hpp"
 #include "rtppacketizationconfig.hpp"
 
 
+#include "impl/utils.hpp"
+
 #include <cassert>
 #include <cassert>
 #include <limits>
 #include <limits>
 #include <random>
 #include <random>
 
 
 namespace rtc {
 namespace rtc {
 
 
+namespace utils = impl::utils;
+
 RtpPacketizationConfig::RtpPacketizationConfig(SSRC ssrc, string cname, uint8_t payloadType,
 RtpPacketizationConfig::RtpPacketizationConfig(SSRC ssrc, string cname, uint8_t payloadType,
                                                uint32_t clockRate, uint8_t videoOrientationId)
                                                uint32_t clockRate, uint8_t videoOrientationId)
     : ssrc(ssrc), cname(cname), payloadType(payloadType), clockRate(clockRate),
     : ssrc(ssrc), cname(cname), payloadType(payloadType), clockRate(clockRate),
@@ -35,10 +39,9 @@ RtpPacketizationConfig::RtpPacketizationConfig(SSRC ssrc, string cname, uint8_t
 	// RFC 3550: The initial value of the sequence number SHOULD be random (unpredictable) to make
 	// RFC 3550: The initial value of the sequence number SHOULD be random (unpredictable) to make
 	// known-plaintext attacks on encryption more difficult [...] The initial value of the timestamp
 	// known-plaintext attacks on encryption more difficult [...] The initial value of the timestamp
 	// SHOULD be random, as for the sequence number.
 	// SHOULD be random, as for the sequence number.
-	std::default_random_engine rng(std::random_device{}());
-	std::uniform_int_distribution<uint32_t> dist(0, std::numeric_limits<uint32_t>::max());
-	sequenceNumber = static_cast<uint16_t>(dist(rng));
-	timestamp = startTimestamp = dist(rng);
+	auto uniform = std::bind(std::uniform_int_distribution<uint32_t>(), utils::random_engine());
+	sequenceNumber = static_cast<uint16_t>(uniform());
+	timestamp = startTimestamp = uniform();
 }
 }
 
 
 double RtpPacketizationConfig::getSecondsFromTimestamp(uint32_t timestamp, uint32_t clockRate) {
 double RtpPacketizationConfig::getSecondsFromTimestamp(uint32_t timestamp, uint32_t clockRate) {