util.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  // Package logic holds the core logic shared by client and server code.
  2. package logic
  3. import (
  4. "crypto/rand"
  5. "encoding/base32"
  6. "encoding/base64"
  7. "encoding/json"
  8. "fmt"
  9. "log/slog"
  10. "net"
  11. "net/http"
  12. "os"
  13. "reflect"
  14. "strings"
  15. "time"
  16. "unicode"
  17. "github.com/blang/semver"
  18. "github.com/c-robinson/iplib"
  19. "github.com/gravitl/netmaker/database"
  20. "github.com/gravitl/netmaker/logger"
  21. )
  // IsBase64 reports whether s is valid padded standard base64 (RFC 4648,
// base64.StdEncoding). It is used to sanity-check public keys, which are
// expected to be base64-encoded. The empty string counts as valid.
func IsBase64(s string) bool {
	if _, err := base64.StdEncoding.DecodeString(s); err != nil {
		return false
	}
	return true
}
  // CheckEndpoint reports whether endpoint has the "host:port" form accepted by
// net.SplitHostPort. IPv6 hosts must be bracketed, e.g. "[fd00::1]:51821".
//
// The previous implementation only counted colons, so every bracketed IPv6
// endpoint was rejected. Malformed strings such as "a]:b" were accepted.
// As before, the port is not checked for being numeric, and an empty host or
// port (e.g. "host:" or ":") is still accepted.
func CheckEndpoint(endpoint string) bool {
	_, _, err := net.SplitHostPort(endpoint)
	return err == nil
}
  // FileExists reports whether f names something that exists and is not a
// directory.
//
// Any error from os.Stat yields false. This covers "does not exist" as well as
// permission errors, a path component that is not a directory, and invalid
// paths (e.g. containing a NUL byte).
func FileExists(f string) bool {
	info, err := os.Stat(f)
	if err != nil {
		// info is nil whenever err != nil. Checking only os.IsNotExist here
		// would dereference a nil info for every other kind of Stat failure.
		return false
	}
	return !info.IsDir()
}
  // IsAddressInCIDR reports whether address lies inside the network described
// by cidr (e.g. "10.0.0.0/24"). A cidr string that net.ParseCIDR cannot parse
// yields false.
func IsAddressInCIDR(address net.IP, cidr string) bool {
	_, network, err := net.ParseCIDR(cidr)
	// Short-circuit keeps us from touching network, which is nil on error.
	return err == nil && network.Contains(address)
}
  49. // SetNetworkNodesLastModified - sets the network nodes last modified
  50. func SetNetworkNodesLastModified(networkName string) error {
  51. timestamp := time.Now().Unix()
  52. network, err := GetParentNetwork(networkName)
  53. if err != nil {
  54. return err
  55. }
  56. network.NodesLastModified = timestamp
  57. data, err := json.Marshal(&network)
  58. if err != nil {
  59. return err
  60. }
  61. err = database.Insert(networkName, string(data), database.NETWORKS_TABLE_NAME)
  62. if err != nil {
  63. return err
  64. }
  65. return nil
  66. }
  67. // RandomString - returns a random string in a charset
  68. func RandomString(length int) string {
  69. randombytes := make([]byte, length)
  70. _, err := rand.Read(randombytes)
  71. if err != nil {
  72. logger.Log(0, "random string", err.Error())
  73. return ""
  74. }
  75. return base32.StdEncoding.EncodeToString(randombytes)[:length]
  76. }
  // StringSliceContains reports whether item occurs anywhere in slice.
// A nil or empty slice contains nothing.
func StringSliceContains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
  // SetVerbosity installs a new default slog logger that writes text-formatted
// records to os.Stdout, filtered at the level selected by logLevel:
//
//	0 -> Info, 1 -> Error, 2 -> Warn, 3 -> Debug, anything else -> Info.
//
// NOTE(review): 0 maps to Info, which is more verbose than 1 (Error) and 2
// (Warn). Presumably 0 means "unset/default"; confirm this is intended.
func SetVerbosity(logLevel int) {
	level := slog.LevelInfo // used for 0 and any unrecognized value
	switch logLevel {
	case 1:
		level = slog.LevelError
	case 2:
		level = slog.LevelWarn
	case 3:
		level = slog.LevelDebug
	}
	// No local named "logger": that would shadow the imported logger package.
	opts := &slog.HandlerOptions{Level: level}
	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, opts)))
}
  107. // NormalizeCIDR - returns the first address of CIDR
  108. func NormalizeCIDR(address string) (string, error) {
  109. ip, IPNet, err := net.ParseCIDR(address)
  110. if err != nil {
  111. return "", err
  112. }
  113. if ip.To4() == nil {
  114. net6 := iplib.Net6FromStr(IPNet.String())
  115. IPNet.IP = net6.FirstAddress()
  116. } else {
  117. net4 := iplib.Net4FromStr(IPNet.String())
  118. IPNet.IP = net4.NetworkAddress()
  119. }
  120. return IPNet.String(), nil
  121. }
  // StringDifference returns, in their original order, the elements of a that
// do not appear in b. Duplicates in a are kept. The result is nil when no
// element of a survives.
func StringDifference(a, b []string) []string {
	exclude := make(map[string]struct{}, len(b))
	for _, s := range b {
		exclude[s] = struct{}{}
	}
	var out []string
	for _, s := range a {
		if _, skip := exclude[s]; skip {
			continue
		}
		out = append(out, s)
	}
	return out
}
  // CheckIfFileExists reports whether something (file or directory) exists at
// filePath, i.e. whether os.Stat succeeds on it.
//
// Any os.Stat error yields false. The previous check treated only
// os.IsNotExist as "missing", so a path under a regular file (ENOTDIR) or an
// invalid path (e.g. containing NUL) was reported as existing. A permission
// error on a parent directory now also yields false, because existence
// cannot be confirmed.
func CheckIfFileExists(filePath string) bool {
	_, err := os.Stat(filePath)
	return err == nil
}
  // RemoveStringSlice returns slice with the element at index i removed,
// preserving the order of the remaining elements.
//
// The removal happens in place: elements after i are shifted left inside
// slice's backing array, and the returned slice shares that array. Callers
// that still hold the original slice will see the shifted contents. It panics
// if i is out of range.
func RemoveStringSlice(slice []string, i int) []string {
	copy(slice[i:], slice[i+1:])
	return slice[:len(slice)-1]
}
  // IsSlicesEqual reports whether a and b have the same length and hold equal
// strings at every index (order matters). A nil slice is equal to an empty one.
func IsSlicesEqual(a, b []string) bool {
	n := len(a)
	if n != len(b) {
		return false
	}
	for i := 0; i < n; i++ {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}
  161. // VersionLessThan checks if v1 < v2 semantically
  162. // dev is the latest version
  163. func VersionLessThan(v1, v2 string) (bool, error) {
  164. if v1 == "dev" {
  165. return false, nil
  166. }
  167. if v2 == "dev" {
  168. return true, nil
  169. }
  170. semVer1 := strings.TrimFunc(v1, func(r rune) bool {
  171. return !unicode.IsNumber(r)
  172. })
  173. semVer2 := strings.TrimFunc(v2, func(r rune) bool {
  174. return !unicode.IsNumber(r)
  175. })
  176. sv1, err := semver.Parse(semVer1)
  177. if err != nil {
  178. return false, fmt.Errorf("failed to parse semver1 (%s): %w", semVer1, err)
  179. }
  180. sv2, err := semver.Parse(semVer2)
  181. if err != nil {
  182. return false, fmt.Errorf("failed to parse semver2 (%s): %w", semVer2, err)
  183. }
  184. return sv1.LT(sv2), nil
  185. }
  // CompareMaps reports whether a and b contain exactly the same keys, with
// values at each key equal according to reflect.DeepEqual. A nil map equals
// an empty map.
func CompareMaps[K comparable, V any](a, b map[K]V) bool {
	if len(a) != len(b) {
		return false
	}
	// With equal lengths, every key of a being present in b implies the key
	// sets are identical.
	for k, va := range a {
		if vb, found := b[k]; !found || !reflect.DeepEqual(va, vb) {
			return false
		}
	}
	return true
}
  // UniqueStrings returns the distinct elements of input in order of first
// occurrence. The result is nil when input is empty.
func UniqueStrings(input []string) []string {
	var result []string
	seen := make(map[string]bool, len(input))
	for _, s := range input {
		if seen[s] {
			continue
		}
		seen[s] = true
		result = append(result, s)
	}
	return result
}
  // GetClientIP returns a best-effort client address for r, checking in order:
//
//  1. the first non-blank entry position of X-Forwarded-For (leftmost hop),
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr (or RemoteAddr verbatim if it has no port).
//
// Header values are whitespace-trimmed. An X-Forwarded-For header whose first
// entry is blank (e.g. " , 1.2.3.4") or an all-blank X-Real-IP is skipped
// instead of being returned as an empty/whitespace string.
//
// NOTE(security): both headers are client-controlled unless a trusted reverse
// proxy overwrites them, so the returned value can be spoofed. Do not use it
// for authorization or trust decisions without confirming the deployment
// topology.
func GetClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
		return xrip
	}
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port (e.g. some test or unix-socket setups).
		return r.RemoteAddr
	}
	return ip
}