WinDNSHelper.cpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. #include "WinDNSHelper.hpp"
  2. #include <WbemIdl.h>
  3. #include <comdef.h>
  4. #include <sstream>
  5. #include <string>
  6. #include <strsafe.h>
  7. #include <vector>
  8. #define MAX_KEY_LENGTH 255
  9. #define MAX_VALUE_NAME 16383
  10. namespace ZeroTier {
  11. BOOL RegDelnodeRecurse(HKEY hKeyRoot, LPTSTR lpSubKey)
  12. {
  13. LPTSTR lpEnd;
  14. LONG lResult;
  15. DWORD dwSize;
  16. TCHAR szName[MAX_PATH];
  17. HKEY hKey;
  18. FILETIME ftWrite;
  19. // First, see if we can delete the key without having
  20. // to recurse.
  21. lResult = RegDeleteKey(hKeyRoot, lpSubKey);
  22. if (lResult == ERROR_SUCCESS)
  23. return TRUE;
  24. lResult = RegOpenKeyEx(hKeyRoot, lpSubKey, 0, KEY_READ, &hKey);
  25. if (lResult != ERROR_SUCCESS) {
  26. if (lResult == ERROR_FILE_NOT_FOUND) {
  27. return TRUE;
  28. }
  29. else {
  30. return FALSE;
  31. }
  32. }
  33. // Check for an ending slash and add one if it is missing.
  34. lpEnd = lpSubKey + lstrlen(lpSubKey);
  35. if (*(lpEnd - 1) != TEXT('\\')) {
  36. *lpEnd = TEXT('\\');
  37. lpEnd++;
  38. *lpEnd = TEXT('\0');
  39. }
  40. // Enumerate the keys
  41. dwSize = MAX_PATH;
  42. lResult = RegEnumKeyEx(hKey, 0, szName, &dwSize, NULL, NULL, NULL, &ftWrite);
  43. if (lResult == ERROR_SUCCESS) {
  44. do {
  45. *lpEnd = TEXT('\0');
  46. StringCchCat(lpSubKey, MAX_PATH * 2, szName);
  47. if (! RegDelnodeRecurse(hKeyRoot, lpSubKey)) {
  48. break;
  49. }
  50. dwSize = MAX_PATH;
  51. lResult = RegEnumKeyEx(hKey, 0, szName, &dwSize, NULL, NULL, NULL, &ftWrite);
  52. } while (lResult == ERROR_SUCCESS);
  53. }
  54. lpEnd--;
  55. *lpEnd = TEXT('\0');
  56. RegCloseKey(hKey);
  57. // Try again to delete the key.
  58. lResult = RegDeleteKey(hKeyRoot, lpSubKey);
  59. if (lResult == ERROR_SUCCESS)
  60. return TRUE;
  61. return FALSE;
  62. }
  63. //*************************************************************
  64. //
  65. // RegDelnode()
  66. //
  67. // Purpose: Deletes a registry key and all its subkeys / values.
  68. //
  69. // Parameters: hKeyRoot - Root key
  70. // lpSubKey - SubKey to delete
  71. //
  72. // Return: TRUE if successful.
  73. // FALSE if an error occurs.
  74. //
  75. //*************************************************************
  76. BOOL RegDelnode(HKEY hKeyRoot, LPCTSTR lpSubKey)
  77. {
  78. TCHAR szDelKey[MAX_PATH * 2];
  79. StringCchCopy(szDelKey, MAX_PATH * 2, lpSubKey);
  80. return RegDelnodeRecurse(hKeyRoot, szDelKey);
  81. }
  82. std::vector<std::string> getSubKeys(const char* key)
  83. {
  84. std::vector<std::string> subkeys;
  85. HKEY hKey;
  86. if (RegOpenKeyExA(HKEY_LOCAL_MACHINE, key, 0, KEY_READ, &hKey) == ERROR_SUCCESS) {
  87. TCHAR achKey[MAX_KEY_LENGTH]; // buffer for subkey name
  88. DWORD cbName; // size of name string
  89. TCHAR achClass[MAX_PATH] = TEXT(""); // buffer for class name
  90. DWORD cchClassName = MAX_PATH; // size of class string
  91. DWORD cSubKeys = 0; // number of subkeys
  92. DWORD cbMaxSubKey; // longest subkey size
  93. DWORD cchMaxClass; // longest class string
  94. DWORD cValues; // number of values for key
  95. DWORD cchMaxValue; // longest value name
  96. DWORD cbMaxValueData; // longest value data
  97. DWORD cbSecurityDescriptor; // size of security descriptor
  98. FILETIME ftLastWriteTime; // last write time
  99. DWORD i, retCode;
  100. TCHAR achValue[MAX_VALUE_NAME];
  101. DWORD cchValue = MAX_VALUE_NAME;
  102. retCode = RegQueryInfoKey(
  103. hKey, // key handle
  104. achClass, // buffer for class name
  105. &cchClassName, // size of class string
  106. NULL, // reserved
  107. &cSubKeys, // number of subkeys
  108. &cbMaxSubKey, // longest subkey size
  109. &cchMaxClass, // longest class string
  110. &cValues, // number of values for this key
  111. &cchMaxValue, // longest value name
  112. &cbMaxValueData, // longest value data
  113. &cbSecurityDescriptor, // security descriptor
  114. &ftLastWriteTime); // last write time
  115. for (i = 0; i < cSubKeys; ++i) {
  116. cbName = MAX_KEY_LENGTH;
  117. retCode = RegEnumKeyEx(hKey, i, achKey, &cbName, NULL, NULL, NULL, &ftLastWriteTime);
  118. if (retCode == ERROR_SUCCESS) {
  119. subkeys.push_back(achKey);
  120. }
  121. }
  122. }
  123. RegCloseKey(hKey);
  124. return subkeys;
  125. }
  126. std::vector<std::string> getValueList(const char* key)
  127. {
  128. std::vector<std::string> values;
  129. HKEY hKey;
  130. if (RegOpenKeyExA(HKEY_LOCAL_MACHINE, key, 0, KEY_READ, &hKey) == ERROR_SUCCESS) {
  131. TCHAR achKey[MAX_KEY_LENGTH]; // buffer for subkey name
  132. DWORD cbName; // size of name string
  133. TCHAR achClass[MAX_PATH] = TEXT(""); // buffer for class name
  134. DWORD cchClassName = MAX_PATH; // size of class string
  135. DWORD cSubKeys = 0; // number of subkeys
  136. DWORD cbMaxSubKey; // longest subkey size
  137. DWORD cchMaxClass; // longest class string
  138. DWORD cValues; // number of values for key
  139. DWORD cchMaxValue; // longest value name
  140. DWORD cbMaxValueData; // longest value data
  141. DWORD cbSecurityDescriptor; // size of security descriptor
  142. FILETIME ftLastWriteTime; // last write time
  143. DWORD i, retCode;
  144. TCHAR achValue[MAX_VALUE_NAME];
  145. DWORD cchValue = MAX_VALUE_NAME;
  146. retCode = RegQueryInfoKey(
  147. hKey, // key handle
  148. achClass, // buffer for class name
  149. &cchClassName, // size of class string
  150. NULL, // reserved
  151. &cSubKeys, // number of subkeys
  152. &cbMaxSubKey, // longest subkey size
  153. &cchMaxClass, // longest class string
  154. &cValues, // number of values for this key
  155. &cchMaxValue, // longest value name
  156. &cbMaxValueData, // longest value data
  157. &cbSecurityDescriptor, // security descriptor
  158. &ftLastWriteTime); // last write time
  159. for (i = 0, retCode = ERROR_SUCCESS; i < cValues; ++i) {
  160. cchValue = MAX_VALUE_NAME;
  161. achValue[0] = '\0';
  162. retCode = RegEnumValue(hKey, i, achValue, &cchValue, NULL, NULL, NULL, NULL);
  163. if (retCode == ERROR_SUCCESS) {
  164. values.push_back(achValue);
  165. }
  166. }
  167. }
  168. RegCloseKey(hKey);
  169. return values;
  170. }
  171. std::pair<bool, std::string> WinDNSHelper::hasDNSConfig(uint64_t nwid)
  172. {
  173. char networkStr[20] = { 0 };
  174. sprintf(networkStr, "%.16llx", nwid);
  175. const char* baseKey = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig";
  176. auto subkeys = getSubKeys(baseKey);
  177. for (auto it = subkeys.begin(); it != subkeys.end(); ++it) {
  178. char sub[MAX_KEY_LENGTH] = { 0 };
  179. sprintf(sub, "%s\\%s", baseKey, it->c_str());
  180. auto dnsRecords = getValueList(sub);
  181. for (auto it2 = dnsRecords.begin(); it2 != dnsRecords.end(); ++it2) {
  182. if ((*it2) == "Comment") {
  183. HKEY hKey;
  184. if (RegOpenKeyExA(HKEY_LOCAL_MACHINE, sub, 0, KEY_READ, &hKey) == ERROR_SUCCESS) {
  185. char buf[16384] = { 0 };
  186. DWORD size = sizeof(buf);
  187. DWORD retCode = RegGetValueA(HKEY_LOCAL_MACHINE, sub, it2->c_str(), RRF_RT_REG_SZ, NULL, &buf, &size);
  188. if (retCode == ERROR_SUCCESS) {
  189. if (std::string(networkStr) == std::string(buf)) {
  190. RegCloseKey(hKey);
  191. return std::make_pair(true, std::string(sub));
  192. }
  193. }
  194. else {
  195. }
  196. }
  197. RegCloseKey(hKey);
  198. }
  199. }
  200. }
  201. return std::make_pair(false, std::string());
  202. }
  203. void WinDNSHelper::setDNS(uint64_t nwid, const char* domain, const std::vector<InetAddress>& servers)
  204. {
  205. auto hasConfig = hasDNSConfig(nwid);
  206. std::stringstream ss;
  207. for (auto it = servers.begin(); it != servers.end(); ++it) {
  208. char ipaddr[256] = { 0 };
  209. ss << it->toIpString(ipaddr);
  210. if ((it + 1) != servers.end()) {
  211. ss << ";";
  212. }
  213. }
  214. std::string serverValue = ss.str();
  215. if (hasConfig.first) {
  216. // update existing config
  217. HKEY dnsKey;
  218. if (RegOpenKeyExA(HKEY_LOCAL_MACHINE, hasConfig.second.c_str(), 0, KEY_READ | KEY_WRITE, &dnsKey) == ERROR_SUCCESS) {
  219. auto retCode = RegSetKeyValueA(dnsKey, NULL, "GenericDNSServers", REG_SZ, serverValue.data(), (DWORD)serverValue.length());
  220. if (retCode != ERROR_SUCCESS) {
  221. fprintf(stderr, "Error writing dns servers: %d\n", retCode);
  222. }
  223. }
  224. }
  225. else {
  226. // add new config
  227. const char* baseKey = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig";
  228. GUID guid;
  229. CoCreateGuid(&guid);
  230. wchar_t guidTmp[128] = { 0 };
  231. char guidStr[128] = { 0 };
  232. StringFromGUID2(guid, guidTmp, 128);
  233. wcstombs(guidStr, guidTmp, 128);
  234. char fullKey[MAX_KEY_LENGTH] = { 0 };
  235. sprintf(fullKey, "%s\\%s", baseKey, guidStr);
  236. HKEY dnsKey;
  237. RegCreateKeyA(HKEY_LOCAL_MACHINE, fullKey, &dnsKey);
  238. if (RegOpenKeyExA(HKEY_LOCAL_MACHINE, fullKey, 0, KEY_READ | KEY_WRITE, &dnsKey) == ERROR_SUCCESS) {
  239. char nwString[32] = { 0 };
  240. sprintf(nwString, "%.16llx", nwid);
  241. RegSetKeyValueA(dnsKey, NULL, "Comment", REG_SZ, nwString, strlen(nwString));
  242. DWORD configOpts = 8;
  243. RegSetKeyValueA(dnsKey, NULL, "ConfigOptions", REG_DWORD, &configOpts, sizeof(DWORD));
  244. RegSetKeyValueA(dnsKey, NULL, "DisplayName", REG_SZ, "", 0);
  245. RegSetKeyValueA(dnsKey, NULL, "GenericDNSServers", REG_SZ, serverValue.data(), serverValue.length());
  246. RegSetKeyValueA(dnsKey, NULL, "IPSECCARestriction", REG_SZ, "", 0);
  247. std::string d = "." + std::string(domain);
  248. RegSetKeyValueA(dnsKey, NULL, "Name", REG_MULTI_SZ, d.data(), d.length());
  249. DWORD version = 2;
  250. RegSetKeyValueA(dnsKey, NULL, "Version", REG_DWORD, &version, sizeof(DWORD));
  251. }
  252. }
  253. }
  254. void WinDNSHelper::removeDNS(uint64_t nwid)
  255. {
  256. auto hasConfig = hasDNSConfig(nwid);
  257. if (hasConfig.first) {
  258. RegDelnode(HKEY_LOCAL_MACHINE, hasConfig.second.c_str());
  259. }
  260. }
  261. } // namespace ZeroTier