dns.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. package logic
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "os"
  8. "regexp"
  9. "sort"
  10. "strings"
  11. validator "github.com/go-playground/validator/v10"
  12. "github.com/gravitl/netmaker/database"
  13. "github.com/gravitl/netmaker/db"
  14. "github.com/gravitl/netmaker/logger"
  15. "github.com/gravitl/netmaker/models"
  16. "github.com/gravitl/netmaker/schema"
  17. "github.com/txn2/txeh"
  18. )
  // GlobalNs describes a public DNS provider: an identifier plus the resolver
  // addresses it publishes.
  type GlobalNs struct {
  	ID  string   `json:"id"`
  	IPs []string `json:"ips"`
  }

  // GlobalNsList holds the built-in public resolvers, keyed by provider name
  // (each key equals the entry's ID). Every provider lists its IPv4 addresses
  // before its IPv6 ones. Keys are written alphabetically only for readability;
  // Go map iteration order is unspecified either way.
  var GlobalNsList = map[string]GlobalNs{
  	"Cloudflare": {
  		ID:  "Cloudflare",
  		IPs: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001"},
  	},
  	"Google": {
  		ID:  "Google",
  		IPs: []string{"8.8.8.8", "8.8.4.4", "2001:4860:4860::8888", "2001:4860:4860::8844"},
  	},
  	"Quad9": {
  		ID:  "Quad9",
  		IPs: []string{"9.9.9.9", "149.112.112.112", "2620:fe::fe", "2620:fe::9"},
  	},
  }
  52. // SetDNS - sets the dns on file
  53. func SetDNS() error {
  54. hostfile, err := txeh.NewHosts(&txeh.HostsConfig{})
  55. if err != nil {
  56. return err
  57. }
  58. var corefilestring string
  59. networks, err := GetNetworks()
  60. if err != nil && !database.IsEmptyRecord(err) {
  61. return err
  62. }
  63. for _, net := range networks {
  64. corefilestring = corefilestring + net.NetID + " "
  65. dns, err := GetDNS(net.NetID)
  66. if err != nil && !database.IsEmptyRecord(err) {
  67. return err
  68. }
  69. for _, entry := range dns {
  70. hostfile.AddHost(entry.Address, entry.Name)
  71. }
  72. }
  73. dns := GetExtclientDNS()
  74. for _, entry := range dns {
  75. hostfile.AddHost(entry.Address, entry.Name)
  76. }
  77. if corefilestring == "" {
  78. corefilestring = "example.com"
  79. }
  80. err = hostfile.SaveAs("./config/dnsconfig/netmaker.hosts")
  81. if err != nil {
  82. return err
  83. }
  84. /* if something goes wrong with server DNS, check here
  85. // commented out bc we were not using IsSplitDNS
  86. if servercfg.IsSplitDNS() {
  87. err = SetCorefile(corefilestring)
  88. }
  89. */
  90. return err
  91. }
  92. // GetDNS - gets the DNS of a current network
  93. func GetDNS(network string) ([]models.DNSEntry, error) {
  94. dns, err := GetNodeDNS(network)
  95. if err != nil && !database.IsEmptyRecord(err) {
  96. return dns, err
  97. }
  98. customdns, err := GetCustomDNS(network)
  99. if err != nil && !database.IsEmptyRecord(err) {
  100. return dns, err
  101. }
  102. dns = append(dns, customdns...)
  103. return dns, nil
  104. }
  // GetExtclientDNS builds one DNS entry per ext client, named
  // "<ClientID>.<Network>" and carrying the client's network and IPv4/IPv6
  // addresses (empty addresses stay empty).
  //
  // If the ext clients cannot be loaded, an empty non-nil slice is returned. If
  // there are simply no ext clients, the result is nil.
  func GetExtclientDNS() []models.DNSEntry {
  	clients, err := GetAllExtClients()
  	if err != nil {
  		return []models.DNSEntry{}
  	}
  	var entries []models.DNSEntry
  	for _, client := range clients {
  		entries = append(entries, models.DNSEntry{
  			Name:     fmt.Sprintf("%s.%s", client.ClientID, client.Network),
  			Network:  client.Network,
  			Address:  client.Address,
  			Address6: client.Address6,
  		})
  	}
  	return entries
  }
  126. // GetNodeDNS - gets the DNS of a network node
  127. func GetNodeDNS(network string) ([]models.DNSEntry, error) {
  128. var dns []models.DNSEntry
  129. nodes, err := GetNetworkNodes(network)
  130. if err != nil {
  131. return dns, err
  132. }
  133. defaultDomain := GetDefaultDomain()
  134. for _, node := range nodes {
  135. if node.Network != network {
  136. continue
  137. }
  138. host, err := GetHost(node.HostID.String())
  139. if err != nil {
  140. continue
  141. }
  142. var entry = models.DNSEntry{}
  143. if defaultDomain == "" {
  144. entry.Name = fmt.Sprintf("%s.%s", host.Name, network)
  145. } else {
  146. entry.Name = fmt.Sprintf("%s.%s.%s", host.Name, network, defaultDomain)
  147. }
  148. entry.Network = network
  149. if node.Address.IP != nil {
  150. entry.Address = node.Address.IP.String()
  151. }
  152. if node.Address6.IP != nil {
  153. entry.Address6 = node.Address6.IP.String()
  154. }
  155. dns = append(dns, entry)
  156. }
  157. return dns, nil
  158. }
  // GetCustomDNS returns the custom DNS records stored in the DNS table whose
  // Network equals network. Records that fail to decode as JSON are skipped.
  //
  // An error from FetchRecords is returned as-is (with a nil slice). Otherwise
  // the error is always nil.
  func GetCustomDNS(network string) ([]models.DNSEntry, error) {
  	var matches []models.DNSEntry
  	records, err := database.FetchRecords(database.DNS_TABLE_NAME)
  	if err != nil {
  		return matches, err
  	}
  	for _, raw := range records {
  		var entry models.DNSEntry
  		if json.Unmarshal([]byte(raw), &entry) != nil || entry.Network != network {
  			continue
  		}
  		matches = append(matches, entry)
  	}
  	return matches, nil
  }
  // SetCorefile writes a CoreDNS Corefile to <cwd>/config/dnsconfig/Corefile,
  // creating the directory (mode 0744) if needed.
  //
  // domains is placed in front of the server block as its zone list (SetDNS
  // builds it as space-separated network IDs). Within that block, names are
  // answered from /root/dnsconfig/netmaker.hosts. Names not found there fall
  // through to forwarding via 8.8.8.8 and 8.8.4.4.
  //
  // NOTE(review): the hosts path inside the Corefile (/root/dnsconfig) differs
  // from the directory written here (./config/dnsconfig). Presumably the CoreDNS
  // container mounts one onto the other — confirm against the deployment.
  func SetCorefile(domains string) error {
  	cwd, err := os.Getwd()
  	if err != nil {
  		return err
  	}
  	configDir := cwd + "/config/dnsconfig"
  	if err = os.MkdirAll(configDir, 0744); err != nil {
  		logger.Log(0, "couldnt find or create /config/dnsconfig")
  		return err
  	}
  	corefile := domains + ` {
  reload 15s
  hosts /root/dnsconfig/netmaker.hosts {
  fallthrough
  }
  forward . 8.8.8.8 8.8.4.4
  log
  }
  `
  	return os.WriteFile(configDir+"/Corefile", []byte(corefile), 0644)
  }
  203. // GetAllDNS - gets all dns entries
  204. func GetAllDNS() ([]models.DNSEntry, error) {
  205. var dns []models.DNSEntry
  206. networks, err := GetNetworks()
  207. if err != nil && !database.IsEmptyRecord(err) {
  208. return []models.DNSEntry{}, err
  209. }
  210. for _, net := range networks {
  211. netdns, err := GetDNS(net.NetID)
  212. if err != nil {
  213. return []models.DNSEntry{}, nil
  214. }
  215. dns = append(dns, netdns...)
  216. }
  217. return dns, nil
  218. }
  219. // GetDNSEntryNum - gets which entry the dns was
  220. func GetDNSEntryNum(domain string, network string) (int, error) {
  221. num := 0
  222. entries, err := GetDNS(network)
  223. if err != nil {
  224. return 0, err
  225. }
  226. for i := 0; i < len(entries); i++ {
  227. if domain == entries[i].Name {
  228. num++
  229. }
  230. }
  231. return num, nil
  232. }
  // SortDNSEntrys sorts the slice in place by Address using plain string
  // comparison. Digits therefore order before letters, and "10.0.0.2" sorts
  // before "9.0.0.1" (the comparison is lexicographic, not numeric). The sort
  // is not stable.
  func SortDNSEntrys(unsortedDNSEntrys []models.DNSEntry) {
  	byAddress := func(a, b int) bool {
  		return unsortedDNSEntrys[a].Address < unsortedDNSEntrys[b].Address
  	}
  	sort.Slice(unsortedDNSEntrys, byAddress)
  }
  // dnsEntryNamePattern accepts one or more ASCII letters, digits, '-' or '.'.
  // It is compiled once at package initialisation rather than on every call.
  var dnsEntryNamePattern = regexp.MustCompile(`^[A-Za-z0-9-.]+$`)

  // IsDNSEntryValid reports whether d is a non-empty string made only of ASCII
  // letters, digits, hyphens and dots. It checks the character set only, not
  // label structure (e.g. "a..b" or "-x" are accepted).
  func IsDNSEntryValid(d string) bool {
  	return dnsEntryNamePattern.MatchString(d)
  }
  244. // ValidateDNSCreate - checks if an entry is valid
  245. func ValidateDNSCreate(entry models.DNSEntry) error {
  246. if !IsDNSEntryValid(entry.Name) {
  247. return errors.New("invalid input. Only uppercase letters (A-Z), lowercase letters (a-z), numbers (0-9), minus sign (-) and dots (.) are allowed")
  248. }
  249. v := validator.New()
  250. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  251. match, _ := regexp.MatchString(`\s`, entry.Name)
  252. return !match
  253. })
  254. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  255. num, err := GetDNSEntryNum(entry.Name, entry.Network)
  256. return err == nil && num == 0
  257. })
  258. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  259. _, err := GetParentNetwork(entry.Network)
  260. return err == nil
  261. })
  262. err := v.Struct(entry)
  263. if err != nil {
  264. for _, e := range err.(validator.ValidationErrors) {
  265. logger.Log(1, e.Error())
  266. }
  267. }
  268. return err
  269. }
  270. // ValidateDNSUpdate - validates a DNS update
  271. func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
  272. v := validator.New()
  273. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  274. match, _ := regexp.MatchString(`\s`, entry.Name)
  275. return !match
  276. })
  277. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  278. //if name & net not changing name we are good
  279. if change.Name == entry.Name && change.Network == entry.Network {
  280. return true
  281. }
  282. num, err := GetDNSEntryNum(change.Name, change.Network)
  283. return err == nil && num == 0
  284. })
  285. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  286. _, err := GetParentNetwork(change.Network)
  287. return err == nil
  288. })
  289. err := v.Struct(change)
  290. if err != nil {
  291. for _, e := range err.(validator.ValidationErrors) {
  292. logger.Log(1, e.Error())
  293. }
  294. }
  295. return err
  296. }
  297. // DeleteDNS - deletes a DNS entry
  298. func DeleteDNS(domain string, network string) error {
  299. key, err := GetRecordKey(domain, network)
  300. if err != nil {
  301. return err
  302. }
  303. err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
  304. return err
  305. }
  // CreateDNS stores entry as JSON in the DNS table under the key produced by
  // GetRecordKey(entry.Name, entry.Network).
  //
  // A zero DNSEntry is returned if the key cannot be built or the entry cannot be
  // marshalled. Otherwise entry is returned together with the Insert error (nil
  // on success).
  func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
  	key, err := GetRecordKey(entry.Name, entry.Network)
  	if err != nil {
  		return models.DNSEntry{}, err
  	}
  	payload, err := json.Marshal(&entry)
  	if err != nil {
  		return models.DNSEntry{}, err
  	}
  	if insertErr := database.Insert(key, string(payload), database.DNS_TABLE_NAME); insertErr != nil {
  		return entry, insertErr
  	}
  	return entry, nil
  }
  // ValidateNameserverReq checks a nameserver request and returns the first
  // problem found. It requires:
  //   - a Name;
  //   - a NetworkID;
  //   - at least one server;
  //   - unless MatchAll is set, at least one match domain, each accepted by
  //     IsValidMatchDomain.
  // The server addresses themselves are not validated here.
  func ValidateNameserverReq(ns schema.Nameserver) error {
  	switch {
  	case ns.Name == "":
  		return errors.New("name is required")
  	case ns.NetworkID == "":
  		return errors.New("network is required")
  	case len(ns.Servers) == 0:
  		return errors.New("atleast one nameserver should be specified")
  	}
  	if ns.MatchAll {
  		return nil
  	}
  	if len(ns.MatchDomains) == 0 {
  		return errors.New("atleast one match domain is required")
  	}
  	for _, domain := range ns.MatchDomains {
  		if !IsValidMatchDomain(domain) {
  			return errors.New("invalid match domain")
  		}
  	}
  	return nil
  }
  // GetNameserversForNode returns the nameserver entries that apply to node.
  //
  // Each enabled (Status) nameserver of the node's network contributes one entry
  // per match domain when it targets every node (tag "*") or shares at least one
  // tag with the node. Internet-gateway nodes additionally get a catch-all "."
  // entry listing every resolver in GlobalNsList, in provider-name order.
  //
  // Listing the network's nameservers is best-effort: a lookup error yields no
  // network entries.
  // NOTE(review): a nameserver with MatchAll set and no MatchDomains (allowed by
  // ValidateNameserverReq) contributes nothing here. Confirm that MatchDomains is
  // populated for match-all nameservers.
  func GetNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
  	ns := &schema.Nameserver{
  		NetworkID: node.Network,
  	}
  	nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
  	for _, nsI := range nsLi {
  		if !nsI.Status {
  			continue
  		}
  		_, applies := nsI.Tags["*"]
  		if !applies {
  			for tagI := range node.Tags {
  				if _, ok := nsI.Tags[tagI.String()]; ok {
  					// One matching tag is enough. Counting every match would
  					// append duplicate entries for multi-tag nodes.
  					applies = true
  					break
  				}
  			}
  		}
  		if !applies {
  			continue
  		}
  		for _, matchDomain := range nsI.MatchDomains {
  			returnNsLi = append(returnNsLi, models.Nameserver{
  				IPs:         nsI.Servers,
  				MatchDomain: matchDomain,
  			})
  		}
  	}
  	if node.IsInternetGateway {
  		// Map iteration order is random. Sort provider names so the resolver list
  		// is the same on every call.
  		providers := make([]string, 0, len(GlobalNsList))
  		for name := range GlobalNsList {
  			providers = append(providers, name)
  		}
  		sort.Strings(providers)
  		globalNs := models.Nameserver{
  			MatchDomain: ".",
  		}
  		for _, name := range providers {
  			globalNs.IPs = append(globalNs.IPs, GlobalNsList[name].IPs...)
  		}
  		returnNsLi = append(returnNsLi, globalNs)
  	}
  	return returnNsLi
  }
  382. func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
  383. if h.DNS != "yes" {
  384. return
  385. }
  386. for _, nodeID := range h.Nodes {
  387. node, err := GetNodeByID(nodeID)
  388. if err != nil {
  389. continue
  390. }
  391. ns := &schema.Nameserver{
  392. NetworkID: node.Network,
  393. }
  394. nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
  395. for _, nsI := range nsLi {
  396. if !nsI.Status {
  397. continue
  398. }
  399. _, all := nsI.Tags["*"]
  400. if all {
  401. for _, matchDomain := range nsI.MatchDomains {
  402. returnNsLi = append(returnNsLi, models.Nameserver{
  403. IPs: nsI.Servers,
  404. MatchDomain: matchDomain,
  405. })
  406. }
  407. continue
  408. }
  409. for tagI := range node.Tags {
  410. if _, ok := nsI.Tags[tagI.String()]; ok {
  411. for _, matchDomain := range nsI.MatchDomains {
  412. returnNsLi = append(returnNsLi, models.Nameserver{
  413. IPs: nsI.Servers,
  414. MatchDomain: matchDomain,
  415. })
  416. }
  417. }
  418. }
  419. }
  420. if node.IsInternetGateway {
  421. globalNs := models.Nameserver{
  422. MatchDomain: ".",
  423. }
  424. for _, nsI := range GlobalNsList {
  425. globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
  426. }
  427. returnNsLi = append(returnNsLi, globalNs)
  428. }
  429. }
  430. return
  431. }
  // matchDomainLabelPattern matches one lowercase LDH label: 1–63 letters,
  // digits or hyphens, not starting or ending with a hyphen. It is compiled once
  // at package initialisation instead of on every IsValidMatchDomain call.
  var matchDomainLabelPattern = regexp.MustCompile(`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$`)

  // IsValidMatchDomain reports whether s is a valid "match domain".
  // Rules (simple/ASCII), applied after trimming surrounding whitespace:
  // - "~." is allowed (match all).
  // - Optional leading "~" allowed (e.g., "~example.com").
  // - Optional single trailing "." allowed (FQDN form).
  // - No wildcards "*", no leading ".", no underscores.
  // - Labels: letters/digits/hyphen (LDH), 1–63 chars, no leading/trailing
  //   hyphen; matching is case-insensitive.
  // - Total length (without "~" and trailing dot) ≤ 253.
  func IsValidMatchDomain(s string) bool {
  	s = strings.TrimSpace(s)
  	if s == "" {
  		return false
  	}
  	if s == "~." { // special case: match-all
  		return true
  	}
  	// Strip an optional leading "~".
  	if strings.HasPrefix(s, "~") {
  		s = s[1:]
  		if s == "" {
  			return false
  		}
  	}
  	// Allow exactly one trailing dot.
  	if strings.HasSuffix(s, ".") {
  		s = s[:len(s)-1]
  		if s == "" {
  			return false
  		}
  	}
  	// Disallow a leading dot, wildcards and underscores.
  	if strings.HasPrefix(s, ".") || strings.ContainsAny(s, "*_") {
  		return false
  	}
  	s = strings.ToLower(s)
  	if len(s) > 253 {
  		return false
  	}
  	for _, label := range strings.Split(s, ".") {
  		// The pattern also bounds label length; the explicit check is a cheap
  		// early reject for empty labels ("a..b") and oversized ones.
  		if len(label) == 0 || len(label) > 63 || !matchDomainLabelPattern.MatchString(label) {
  			return false
  		}
  	}
  	return true
  }