dns.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. package logic
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "regexp"
  8. "sort"
  9. "strings"
  10. validator "github.com/go-playground/validator/v10"
  11. "github.com/gravitl/netmaker/database"
  12. "github.com/gravitl/netmaker/logger"
  13. "github.com/gravitl/netmaker/models"
  14. "github.com/gravitl/netmaker/schema"
  15. "github.com/txn2/txeh"
  16. )
  17. // SetDNS - sets the dns on file
  18. func SetDNS() error {
  19. hostfile, err := txeh.NewHosts(&txeh.HostsConfig{})
  20. if err != nil {
  21. return err
  22. }
  23. var corefilestring string
  24. networks, err := GetNetworks()
  25. if err != nil && !database.IsEmptyRecord(err) {
  26. return err
  27. }
  28. for _, net := range networks {
  29. corefilestring = corefilestring + net.NetID + " "
  30. dns, err := GetDNS(net.NetID)
  31. if err != nil && !database.IsEmptyRecord(err) {
  32. return err
  33. }
  34. for _, entry := range dns {
  35. hostfile.AddHost(entry.Address, entry.Name)
  36. }
  37. }
  38. dns := GetExtclientDNS()
  39. for _, entry := range dns {
  40. hostfile.AddHost(entry.Address, entry.Name)
  41. }
  42. if corefilestring == "" {
  43. corefilestring = "example.com"
  44. }
  45. err = hostfile.SaveAs("./config/dnsconfig/netmaker.hosts")
  46. if err != nil {
  47. return err
  48. }
  49. /* if something goes wrong with server DNS, check here
  50. // commented out bc we were not using IsSplitDNS
  51. if servercfg.IsSplitDNS() {
  52. err = SetCorefile(corefilestring)
  53. }
  54. */
  55. return err
  56. }
  57. // GetDNS - gets the DNS of a current network
  58. func GetDNS(network string) ([]models.DNSEntry, error) {
  59. dns, err := GetNodeDNS(network)
  60. if err != nil && !database.IsEmptyRecord(err) {
  61. return dns, err
  62. }
  63. customdns, err := GetCustomDNS(network)
  64. if err != nil && !database.IsEmptyRecord(err) {
  65. return dns, err
  66. }
  67. dns = append(dns, customdns...)
  68. return dns, nil
  69. }
  70. // GetExtclientDNS - gets all extclients dns entries
  71. func GetExtclientDNS() []models.DNSEntry {
  72. extclients, err := GetAllExtClients()
  73. if err != nil {
  74. return []models.DNSEntry{}
  75. }
  76. var dns []models.DNSEntry
  77. for _, extclient := range extclients {
  78. var entry = models.DNSEntry{}
  79. entry.Name = fmt.Sprintf("%s.%s", extclient.ClientID, extclient.Network)
  80. entry.Network = extclient.Network
  81. if extclient.Address != "" {
  82. entry.Address = extclient.Address
  83. }
  84. if extclient.Address6 != "" {
  85. entry.Address6 = extclient.Address6
  86. }
  87. dns = append(dns, entry)
  88. }
  89. return dns
  90. }
  91. // GetNodeDNS - gets the DNS of a network node
  92. func GetNodeDNS(network string) ([]models.DNSEntry, error) {
  93. var dns []models.DNSEntry
  94. nodes, err := GetNetworkNodes(network)
  95. if err != nil {
  96. return dns, err
  97. }
  98. defaultDomain := GetDefaultDomain()
  99. for _, node := range nodes {
  100. if node.Network != network {
  101. continue
  102. }
  103. host, err := GetHost(node.HostID.String())
  104. if err != nil {
  105. continue
  106. }
  107. var entry = models.DNSEntry{}
  108. if defaultDomain == "" {
  109. entry.Name = fmt.Sprintf("%s.%s", host.Name, network)
  110. } else {
  111. entry.Name = fmt.Sprintf("%s.%s.%s", host.Name, network, defaultDomain)
  112. }
  113. entry.Network = network
  114. if node.Address.IP != nil {
  115. entry.Address = node.Address.IP.String()
  116. }
  117. if node.Address6.IP != nil {
  118. entry.Address6 = node.Address6.IP.String()
  119. }
  120. dns = append(dns, entry)
  121. }
  122. return dns, nil
  123. }
  124. // GetCustomDNS - gets the custom DNS of a network
  125. func GetCustomDNS(network string) ([]models.DNSEntry, error) {
  126. var dns []models.DNSEntry
  127. collection, err := database.FetchRecords(database.DNS_TABLE_NAME)
  128. if err != nil {
  129. return dns, err
  130. }
  131. for _, value := range collection { // filter for entries based on network
  132. var entry models.DNSEntry
  133. if err := json.Unmarshal([]byte(value), &entry); err != nil {
  134. continue
  135. }
  136. if entry.Network == network {
  137. dns = append(dns, entry)
  138. }
  139. }
  140. return dns, err
  141. }
  142. // SetCorefile - sets the core file of the system
  143. func SetCorefile(domains string) error {
  144. dir, err := os.Getwd()
  145. if err != nil {
  146. return err
  147. }
  148. err = os.MkdirAll(dir+"/config/dnsconfig", 0744)
  149. if err != nil {
  150. logger.Log(0, "couldnt find or create /config/dnsconfig")
  151. return err
  152. }
  153. corefile := domains + ` {
  154. reload 15s
  155. hosts /root/dnsconfig/netmaker.hosts {
  156. fallthrough
  157. }
  158. forward . 8.8.8.8 8.8.4.4
  159. log
  160. }
  161. `
  162. err = os.WriteFile(dir+"/config/dnsconfig/Corefile", []byte(corefile), 0644)
  163. if err != nil {
  164. return err
  165. }
  166. return err
  167. }
  168. // GetAllDNS - gets all dns entries
  169. func GetAllDNS() ([]models.DNSEntry, error) {
  170. var dns []models.DNSEntry
  171. networks, err := GetNetworks()
  172. if err != nil && !database.IsEmptyRecord(err) {
  173. return []models.DNSEntry{}, err
  174. }
  175. for _, net := range networks {
  176. netdns, err := GetDNS(net.NetID)
  177. if err != nil {
  178. return []models.DNSEntry{}, nil
  179. }
  180. dns = append(dns, netdns...)
  181. }
  182. return dns, nil
  183. }
  184. // GetDNSEntryNum - gets which entry the dns was
  185. func GetDNSEntryNum(domain string, network string) (int, error) {
  186. num := 0
  187. entries, err := GetDNS(network)
  188. if err != nil {
  189. return 0, err
  190. }
  191. for i := 0; i < len(entries); i++ {
  192. if domain == entries[i].Name {
  193. num++
  194. }
  195. }
  196. return num, nil
  197. }
  198. // SortDNSEntrys - Sorts slice of DNSEnteys by their Address alphabetically with numbers first
  199. func SortDNSEntrys(unsortedDNSEntrys []models.DNSEntry) {
  200. sort.Slice(unsortedDNSEntrys, func(i, j int) bool {
  201. return unsortedDNSEntrys[i].Address < unsortedDNSEntrys[j].Address
  202. })
  203. }
  204. // IsNetworkNameValid - checks if a netid of a network uses valid characters
  205. func IsDNSEntryValid(d string) bool {
  206. re := regexp.MustCompile(`^[A-Za-z0-9-.]+$`)
  207. return re.MatchString(d)
  208. }
  209. // ValidateDNSCreate - checks if an entry is valid
  210. func ValidateDNSCreate(entry models.DNSEntry) error {
  211. if !IsDNSEntryValid(entry.Name) {
  212. return errors.New("invalid input. Only uppercase letters (A-Z), lowercase letters (a-z), numbers (0-9), minus sign (-) and dots (.) are allowed")
  213. }
  214. v := validator.New()
  215. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  216. match, _ := regexp.MatchString(`\s`, entry.Name)
  217. return !match
  218. })
  219. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  220. num, err := GetDNSEntryNum(entry.Name, entry.Network)
  221. return err == nil && num == 0
  222. })
  223. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  224. _, err := GetParentNetwork(entry.Network)
  225. return err == nil
  226. })
  227. err := v.Struct(entry)
  228. if err != nil {
  229. for _, e := range err.(validator.ValidationErrors) {
  230. logger.Log(1, e.Error())
  231. }
  232. }
  233. return err
  234. }
  235. // ValidateDNSUpdate - validates a DNS update
  236. func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
  237. v := validator.New()
  238. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  239. match, _ := regexp.MatchString(`\s`, entry.Name)
  240. return !match
  241. })
  242. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  243. //if name & net not changing name we are good
  244. if change.Name == entry.Name && change.Network == entry.Network {
  245. return true
  246. }
  247. num, err := GetDNSEntryNum(change.Name, change.Network)
  248. return err == nil && num == 0
  249. })
  250. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  251. _, err := GetParentNetwork(change.Network)
  252. return err == nil
  253. })
  254. err := v.Struct(change)
  255. if err != nil {
  256. for _, e := range err.(validator.ValidationErrors) {
  257. logger.Log(1, e.Error())
  258. }
  259. }
  260. return err
  261. }
  262. // DeleteDNS - deletes a DNS entry
  263. func DeleteDNS(domain string, network string) error {
  264. key, err := GetRecordKey(domain, network)
  265. if err != nil {
  266. return err
  267. }
  268. err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
  269. return err
  270. }
  271. // CreateDNS - creates a DNS entry
  272. func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
  273. k, err := GetRecordKey(entry.Name, entry.Network)
  274. if err != nil {
  275. return models.DNSEntry{}, err
  276. }
  277. data, err := json.Marshal(&entry)
  278. if err != nil {
  279. return models.DNSEntry{}, err
  280. }
  281. err = database.Insert(k, string(data), database.DNS_TABLE_NAME)
  282. return entry, err
  283. }
  284. func ValidateNameserverReq(ns models.NameserverReq) error {
  285. if ns.Name == "" {
  286. return errors.New("name is required")
  287. }
  288. if len(ns.Servers) == 0 {
  289. return errors.New("atleast one nameserver should be specified")
  290. }
  291. if !IsValidMatchDomain(ns.MatchDomain) {
  292. return errors.New("invalid match domain")
  293. }
  294. return nil
  295. }
  296. func ValidateUpdateNameserverReq(updateNs schema.Nameserver) error {
  297. if updateNs.Name == "" {
  298. return errors.New("name is required")
  299. }
  300. if len(updateNs.Servers) == 0 {
  301. return errors.New("atleast one nameserver should be specified")
  302. }
  303. if !IsValidMatchDomain(updateNs.MatchDomain) {
  304. return errors.New("invalid match domain")
  305. }
  306. return nil
  307. }
  308. // IsValidMatchDomain reports whether s is a valid "match domain".
  309. // Rules (simple/ASCII):
  310. // - "~." is allowed (match all).
  311. // - Optional leading "~" allowed (e.g., "~example.com").
  312. // - Optional single trailing "." allowed (FQDN form).
  313. // - No wildcards "*", no leading ".", no underscores.
  314. // - Labels: letters/digits/hyphen (LDH), 1–63 chars, no leading/trailing hyphen.
  315. // - Total length (without trailing dot) ≤ 253.
  316. func IsValidMatchDomain(s string) bool {
  317. s = strings.TrimSpace(s)
  318. if s == "" {
  319. return false
  320. }
  321. if s == "~." { // special case: match-all
  322. return true
  323. }
  324. // Strip optional leading "~"
  325. if strings.HasPrefix(s, "~") {
  326. s = s[1:]
  327. if s == "" {
  328. return false
  329. }
  330. }
  331. // Allow exactly one trailing dot
  332. if strings.HasSuffix(s, ".") {
  333. s = s[:len(s)-1]
  334. if s == "" {
  335. return false
  336. }
  337. }
  338. // Disallow leading dot, wildcards, underscores
  339. if strings.HasPrefix(s, ".") || strings.Contains(s, "*") || strings.Contains(s, "_") {
  340. return false
  341. }
  342. // Lowercase for ASCII checks
  343. s = strings.ToLower(s)
  344. // Length check
  345. if len(s) > 253 {
  346. return false
  347. }
  348. // Label regex: LDH, 1–63, no leading/trailing hyphen
  349. reLabel := regexp.MustCompile(`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$`)
  350. parts := strings.Split(s, ".")
  351. for _, lbl := range parts {
  352. if len(lbl) == 0 || len(lbl) > 63 {
  353. return false
  354. }
  355. if !reLabel.MatchString(lbl) {
  356. return false
  357. }
  358. }
  359. return true
  360. }