util.go 6.1 KB


// Package logic contains the shared logic used by the client and server code.
  2. package logic
  3. import (
  4. "crypto/rand"
  5. "encoding/base32"
  6. "encoding/base64"
  7. "encoding/json"
  8. "fmt"
  9. "log/slog"
  10. "net"
  11. "net/http"
  12. "os"
  13. "reflect"
  14. "strings"
  15. "time"
  16. "unicode"
  17. "github.com/blang/semver"
  18. "github.com/c-robinson/iplib"
  19. "github.com/gravitl/netmaker/database"
  20. "github.com/gravitl/netmaker/logger"
  21. "github.com/gravitl/netmaker/models"
  22. )
  23. // IsBase64 - checks if a string is in base64 format
  24. // This is used to validate public keys (make sure they're base64 encoded like all public keys should be).
  25. func IsBase64(s string) bool {
  26. _, err := base64.StdEncoding.DecodeString(s)
  27. return err == nil
  28. }
  29. // CheckEndpoint - checks if an endpoint is valid
  30. func CheckEndpoint(endpoint string) bool {
  31. endpointarr := strings.Split(endpoint, ":")
  32. return len(endpointarr) == 2
  33. }
  34. // FileExists - checks if local file exists
  35. func FileExists(f string) bool {
  36. info, err := os.Stat(f)
  37. if os.IsNotExist(err) {
  38. return false
  39. }
  40. return !info.IsDir()
  41. }
  42. // IsAddressInCIDR - util to see if an address is in a cidr or not
  43. func IsAddressInCIDR(address net.IP, cidr string) bool {
  44. var _, currentCIDR, cidrErr = net.ParseCIDR(cidr)
  45. if cidrErr != nil {
  46. return false
  47. }
  48. return currentCIDR.Contains(address)
  49. }
  50. // SetNetworkNodesLastModified - sets the network nodes last modified
  51. func SetNetworkNodesLastModified(networkName string) error {
  52. timestamp := time.Now().Unix()
  53. network, err := GetParentNetwork(networkName)
  54. if err != nil {
  55. return err
  56. }
  57. network.NodesLastModified = timestamp
  58. data, err := json.Marshal(&network)
  59. if err != nil {
  60. return err
  61. }
  62. err = database.Insert(networkName, string(data), database.NETWORKS_TABLE_NAME)
  63. if err != nil {
  64. return err
  65. }
  66. return nil
  67. }
  68. // RandomString - returns a random string in a charset
  69. func RandomString(length int) string {
  70. randombytes := make([]byte, length)
  71. _, err := rand.Read(randombytes)
  72. if err != nil {
  73. logger.Log(0, "random string", err.Error())
  74. return ""
  75. }
  76. return base32.StdEncoding.EncodeToString(randombytes)[:length]
  77. }
  78. // StringSliceContains - sees if a string slice contains a string element
  79. func StringSliceContains(slice []string, item string) bool {
  80. for _, s := range slice {
  81. if s == item {
  82. return true
  83. }
  84. }
  85. return false
  86. }
  87. func SetVerbosity(logLevel int) {
  88. var level slog.Level
  89. switch logLevel {
  90. case 0:
  91. level = slog.LevelInfo
  92. case 1:
  93. level = slog.LevelError
  94. case 2:
  95. level = slog.LevelWarn
  96. case 3:
  97. level = slog.LevelDebug
  98. default:
  99. level = slog.LevelInfo
  100. }
  101. // Create the logger with the chosen level
  102. handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
  103. Level: level,
  104. })
  105. logger := slog.New(handler)
  106. slog.SetDefault(logger)
  107. }
  108. // NormalizeCIDR - returns the first address of CIDR
  109. func NormalizeCIDR(address string) (string, error) {
  110. ip, IPNet, err := net.ParseCIDR(address)
  111. if err != nil {
  112. return "", err
  113. }
  114. if ip.To4() == nil {
  115. net6 := iplib.Net6FromStr(IPNet.String())
  116. IPNet.IP = net6.FirstAddress()
  117. } else {
  118. net4 := iplib.Net4FromStr(IPNet.String())
  119. IPNet.IP = net4.NetworkAddress()
  120. }
  121. return IPNet.String(), nil
  122. }
  123. // StringDifference - returns the elements in `a` that aren't in `b`.
  124. func StringDifference(a, b []string) []string {
  125. mb := make(map[string]struct{}, len(b))
  126. for _, x := range b {
  127. mb[x] = struct{}{}
  128. }
  129. var diff []string
  130. for _, x := range a {
  131. if _, found := mb[x]; !found {
  132. diff = append(diff, x)
  133. }
  134. }
  135. return diff
  136. }
  137. // CheckIfFileExists - checks if file exists or not in the given path
  138. func CheckIfFileExists(filePath string) bool {
  139. if _, err := os.Stat(filePath); os.IsNotExist(err) {
  140. return false
  141. }
  142. return true
  143. }
  144. // RemoveStringSlice - removes an element at given index i
  145. // from a given string slice
  146. func RemoveStringSlice(slice []string, i int) []string {
  147. return append(slice[:i], slice[i+1:]...)
  148. }
  149. // IsSlicesEqual tells whether a and b contain the same elements.
  150. // A nil argument is equivalent to an empty slice.
  151. func IsSlicesEqual(a, b []string) bool {
  152. if len(a) != len(b) {
  153. return false
  154. }
  155. for i, v := range a {
  156. if v != b[i] {
  157. return false
  158. }
  159. }
  160. return true
  161. }
  162. // VersionLessThan checks if v1 < v2 semantically
  163. // dev is the latest version
  164. func VersionLessThan(v1, v2 string) (bool, error) {
  165. if v1 == "dev" {
  166. return false, nil
  167. }
  168. if v2 == "dev" {
  169. return true, nil
  170. }
  171. semVer1 := strings.TrimFunc(v1, func(r rune) bool {
  172. return !unicode.IsNumber(r)
  173. })
  174. semVer2 := strings.TrimFunc(v2, func(r rune) bool {
  175. return !unicode.IsNumber(r)
  176. })
  177. sv1, err := semver.Parse(semVer1)
  178. if err != nil {
  179. return false, fmt.Errorf("failed to parse semver1 (%s): %w", semVer1, err)
  180. }
  181. sv2, err := semver.Parse(semVer2)
  182. if err != nil {
  183. return false, fmt.Errorf("failed to parse semver2 (%s): %w", semVer2, err)
  184. }
  185. return sv1.LT(sv2), nil
  186. }
  187. // Compare any two maps with any key and value types
  188. func CompareMaps[K comparable, V any](a, b map[K]V) bool {
  189. if len(a) != len(b) {
  190. return false
  191. }
  192. for key, valA := range a {
  193. valB, ok := b[key]
  194. if !ok {
  195. return false
  196. }
  197. if !reflect.DeepEqual(valA, valB) {
  198. return false
  199. }
  200. }
  201. return true
  202. }
  203. func UniqueStrings(input []string) []string {
  204. seen := make(map[string]struct{})
  205. var result []string
  206. for _, val := range input {
  207. if _, ok := seen[val]; !ok {
  208. seen[val] = struct{}{}
  209. result = append(result, val)
  210. }
  211. }
  212. return result
  213. }
  214. func GetClientIP(r *http.Request) string {
  215. // Trust X-Forwarded-For first
  216. if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
  217. parts := strings.Split(xff, ",")
  218. return strings.TrimSpace(parts[0])
  219. }
  220. if xrip := r.Header.Get("X-Real-IP"); xrip != "" {
  221. return xrip
  222. }
  223. ip, _, err := net.SplitHostPort(r.RemoteAddr)
  224. if err != nil {
  225. return r.RemoteAddr
  226. }
  227. return ip
  228. }
  229. // CompareIfaceSlices compares two slices of Iface for deep equality (order-sensitive)
  230. func CompareIfaceSlices(a, b []models.Iface) bool {
  231. if len(a) != len(b) {
  232. return false
  233. }
  234. for i := range a {
  235. if !compareIface(a[i], b[i]) {
  236. return false
  237. }
  238. }
  239. return true
  240. }
  241. func compareIface(a, b models.Iface) bool {
  242. return a.Name == b.Name &&
  243. a.Address.IP.Equal(b.Address.IP) &&
  244. a.Address.Mask.String() == b.Address.Mask.String() &&
  245. a.AddressString == b.AddressString
  246. }