windows_registry_setting.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. """
  2. All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
  3. its licensors.
  4. For complete copyright and license terms please see the LICENSE at the root of this
  5. distribution (the "License"). All use of this software is governed by the License,
  6. or, if provided, by the license below or the license accompanying this file. Do not
  7. remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. Class for querying and setting a windows registry setting.
  10. """
import logging
from typing import Any, List, Optional, Tuple, Union
from winreg import (
    CreateKey,
    OpenKey,
    QueryValueEx,
    DeleteValue,
    SetValueEx,
    KEY_ALL_ACCESS,
    KEY_WRITE,
    REG_SZ,
    REG_MULTI_SZ,
    REG_DWORD,
    HKEY_CURRENT_USER,
)

import pytest

from .platform_setting import PlatformSetting
# Module logger; WindowsRegistrySetting reports registry read/write/delete problems through it.
logger = logging.getLogger(__name__)
  29. class WindowsRegistrySetting(PlatformSetting):
  30. def __init__(self, workspace: pytest.fixture, subkey: str, key: str, hive: Optional[str] = None) -> None:
  31. super().__init__(workspace, subkey, key)
  32. self._hive = None
  33. try:
  34. if hive is not None:
  35. self._hive = self._str_to_hive(hive)
  36. except ValueError:
  37. logger.warning(f"Windows Registry Hive {hive} not recognized, using default: HKEY_CURRENT_USER")
  38. finally:
  39. if self._hive is None:
  40. self._hive = HKEY_CURRENT_USER
  41. def get_value(self, get_type: Optional[bool] = False) -> Any:
  42. """Retrieves the fast scan value in Windows registry (and optionally the type). If entry DNE, returns None."""
  43. if self.entry_exists():
  44. registryKey = OpenKey(self._hive, self._key)
  45. value = QueryValueEx(registryKey, self._subkey)
  46. registryKey.Close()
  47. # Convert windows data type to universal data type flag: PlatformSettings.DATA_TYPE
  48. # And handles unicode conversion for strings
  49. value = self._convert_value(value)
  50. return value if get_type else value[0]
  51. else:
  52. logger.warning(f"Could not retrieve Registry entry; key: {self._key}, subkey: {self._subkey}.")
  53. return None
  54. def set_value(self, value: Any) -> bool:
  55. """Sets the Windows registry value."""
  56. value, win_type = self._format_data(value)
  57. registryKey = None
  58. result = False
  59. try:
  60. CreateKey(self._hive, self._subkey)
  61. registryKey = OpenKey(self._hive, self._key, 0, KEY_WRITE)
  62. SetValueEx(registryKey, self._subkey, 0, win_type, value)
  63. result = True
  64. except WindowsError as e:
  65. logger.warning(f"Windows error caught while setting fast scan registry: {e}")
  66. finally:
  67. if registryKey is not None:
  68. # Close key if it's been opened successfully
  69. registryKey.Close()
  70. return result
  71. def delete_entry(self) -> bool:
  72. """Deletes the Windows registry entry for fast scan enabled"""
  73. try:
  74. if self.entry_exists():
  75. registryKey = OpenKey(self._hive, self._key, 0, KEY_ALL_ACCESS)
  76. DeleteValue(registryKey, self._subkey)
  77. registryKey.Close()
  78. return True
  79. except WindowsError:
  80. logger.error(f"Could not delete registry entry; key: {self._key}, subkey: {self._subkey}")
  81. finally:
  82. return False
  83. def entry_exists(self) -> bool:
  84. """Checks for existence of the setting in Windows registry."""
  85. try:
  86. # Attempt to open and query key. If fails then the entry DNE
  87. registryKey = OpenKey(self._hive, self._key)
  88. QueryValueEx(registryKey, self._subkey)
  89. registryKey.Close()
  90. return True
  91. except WindowsError:
  92. return False
  93. @staticmethod
  94. def _format_data(value: bool or int or str or List[str]) -> Tuple[int or str or List[str], int]:
  95. """Formats the type of the value provided. Returns the formatted value and the windows registry type (int)."""
  96. if type(value) == str:
  97. return value, REG_SZ
  98. elif type(value) == bool:
  99. value = "true" if value else "false"
  100. return value, REG_SZ
  101. elif type(value) == int or type(value) == float:
  102. if type(value) == float:
  103. logger.warning(f"Windows registry does not support floats. Truncating {value} to integer")
  104. value = int(value)
  105. return value, REG_DWORD
  106. elif type(value) == list:
  107. for single_value in value:
  108. if type(single_value) != str:
  109. # fmt:off
  110. raise ValueError(
  111. f"Windows Registry lists only support strings, got a {type(single_value)} in the list")
  112. # fmt:on
  113. return value, REG_MULTI_SZ
  114. else:
  115. raise ValueError(f"Windows registry expected types: int, str and [str], found {type(value)}")
  116. @staticmethod
  117. def _convert_value(value_tuple: Tuple[Any, int]) -> Tuple[Any, PlatformSetting.DATA_TYPE]:
  118. """Converts the Windows registry data and type (tuple) to a (standardized) data and PlatformSetting.DATA_TYPE"""
  119. value, windows_type = value_tuple
  120. if windows_type == REG_SZ:
  121. # Convert from unicode to string
  122. return value, PlatformSetting.DATA_TYPE.STR
  123. elif windows_type == REG_MULTI_SZ:
  124. # Convert from unicode to string
  125. return [string for string in value], PlatformSetting.DATA_TYPE.STR_LIST
  126. elif windows_type == REG_DWORD:
  127. return value, PlatformSetting.DATA_TYPE.INT
  128. else:
  129. raise ValueError(f"Type flag not recognized: {windows_type}")
  130. @staticmethod
  131. def _str_to_hive(hive_str: str) -> int:
  132. """Converts a string to a Windows Registry Hive enum (int)"""
  133. from winreg import HKEY_CLASSES_ROOT, HKEY_CURRENT_CONFIG, HKEY_LOCAL_MACHINE, HKEY_USERS
  134. lower = hive_str.lower()
  135. if lower == "hkey_current_user" or lower == "current_user":
  136. return HKEY_CURRENT_USER
  137. elif lower == "hkey_classes_root" or lower == "classes_root":
  138. return HKEY_CLASSES_ROOT
  139. elif lower == "hkey_current_config" or lower == "current_config":
  140. return HKEY_CURRENT_CONFIG
  141. elif lower == "hkey_local_machine" or lower == "local_machine":
  142. return HKEY_LOCAL_MACHINE
  143. elif lower == "hkey_users" or lower == "users":
  144. return HKEY_USERS
  145. else:
  146. raise ValueError(f"Hive: {hive_str} not recognized")