dns.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. package logic
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "os"
  8. "regexp"
  9. "sort"
  10. "strings"
  11. validator "github.com/go-playground/validator/v10"
  12. "github.com/gravitl/netmaker/database"
  13. "github.com/gravitl/netmaker/db"
  14. "github.com/gravitl/netmaker/logger"
  15. "github.com/gravitl/netmaker/models"
  16. "github.com/gravitl/netmaker/schema"
  17. "github.com/gravitl/netmaker/servercfg"
  18. "github.com/txn2/txeh"
  19. )
  20. var GetNameserversForNode = getNameserversForNode
  21. var GetNameserversForHost = getNameserversForHost
  22. var ValidateNameserverReq = validateNameserverReq
  // GlobalNs describes a well-known public DNS provider together with the
// resolver addresses (IPv4 and IPv6) it publishes.
type GlobalNs struct {
	// ID is the provider's name.
	ID string `json:"id"`
	// IPs lists the provider's resolver addresses.
	IPs []string `json:"ips"`
}
  27. var GlobalNsList = map[string]GlobalNs{
  28. "Google": {
  29. ID: "Google",
  30. IPs: []string{
  31. "8.8.8.8",
  32. "8.8.4.4",
  33. "2001:4860:4860::8888",
  34. "2001:4860:4860::8844",
  35. },
  36. },
  37. "Cloudflare": {
  38. ID: "Cloudflare",
  39. IPs: []string{
  40. "1.1.1.1",
  41. "1.0.0.1",
  42. "2606:4700:4700::1111",
  43. "2606:4700:4700::1001",
  44. },
  45. },
  46. "Quad9": {
  47. ID: "Quad9",
  48. IPs: []string{
  49. "9.9.9.9",
  50. "149.112.112.112",
  51. "2620:fe::fe",
  52. "2620:fe::9",
  53. },
  54. },
  55. }
  56. // SetDNS - sets the dns on file
  57. func SetDNS() error {
  58. hostfile, err := txeh.NewHosts(&txeh.HostsConfig{})
  59. if err != nil {
  60. return err
  61. }
  62. var corefilestring string
  63. networks, err := GetNetworks()
  64. if err != nil && !database.IsEmptyRecord(err) {
  65. return err
  66. }
  67. for _, net := range networks {
  68. corefilestring = corefilestring + net.NetID + " "
  69. dns, err := GetDNS(net.NetID)
  70. if err != nil && !database.IsEmptyRecord(err) {
  71. return err
  72. }
  73. for _, entry := range dns {
  74. hostfile.AddHost(entry.Address, entry.Name)
  75. }
  76. }
  77. dns := GetExtclientDNS()
  78. for _, entry := range dns {
  79. hostfile.AddHost(entry.Address, entry.Name)
  80. }
  81. if corefilestring == "" {
  82. corefilestring = "example.com"
  83. }
  84. err = hostfile.SaveAs("./config/dnsconfig/netmaker.hosts")
  85. if err != nil {
  86. return err
  87. }
  88. /* if something goes wrong with server DNS, check here
  89. // commented out bc we were not using IsSplitDNS
  90. if servercfg.IsSplitDNS() {
  91. err = SetCorefile(corefilestring)
  92. }
  93. */
  94. return err
  95. }
  96. // GetDNS - gets the DNS of a current network
  97. func GetDNS(network string) ([]models.DNSEntry, error) {
  98. dns, err := GetNodeDNS(network)
  99. if err != nil && !database.IsEmptyRecord(err) {
  100. return dns, err
  101. }
  102. customdns, err := GetCustomDNS(network)
  103. if err != nil && !database.IsEmptyRecord(err) {
  104. return dns, err
  105. }
  106. dns = append(dns, customdns...)
  107. return dns, nil
  108. }
  109. // GetExtclientDNS - gets all extclients dns entries
  110. func GetExtclientDNS() []models.DNSEntry {
  111. extclients, err := GetAllExtClients()
  112. if err != nil {
  113. return []models.DNSEntry{}
  114. }
  115. var dns []models.DNSEntry
  116. for _, extclient := range extclients {
  117. var entry = models.DNSEntry{}
  118. entry.Name = fmt.Sprintf("%s.%s", extclient.ClientID, extclient.Network)
  119. entry.Network = extclient.Network
  120. if extclient.Address != "" {
  121. entry.Address = extclient.Address
  122. }
  123. if extclient.Address6 != "" {
  124. entry.Address6 = extclient.Address6
  125. }
  126. dns = append(dns, entry)
  127. }
  128. return dns
  129. }
  130. // GetNodeDNS - gets the DNS of a network node
  131. func GetNodeDNS(network string) ([]models.DNSEntry, error) {
  132. var dns []models.DNSEntry
  133. nodes, err := GetNetworkNodes(network)
  134. if err != nil {
  135. return dns, err
  136. }
  137. defaultDomain := GetDefaultDomain()
  138. for _, node := range nodes {
  139. if node.Network != network {
  140. continue
  141. }
  142. host, err := GetHost(node.HostID.String())
  143. if err != nil {
  144. continue
  145. }
  146. var entry = models.DNSEntry{}
  147. if defaultDomain == "" {
  148. entry.Name = fmt.Sprintf("%s.%s", host.Name, network)
  149. } else {
  150. entry.Name = fmt.Sprintf("%s.%s.%s", host.Name, network, defaultDomain)
  151. }
  152. entry.Network = network
  153. if node.Address.IP != nil {
  154. entry.Address = node.Address.IP.String()
  155. }
  156. if node.Address6.IP != nil {
  157. entry.Address6 = node.Address6.IP.String()
  158. }
  159. dns = append(dns, entry)
  160. }
  161. return dns, nil
  162. }
  163. func GetGwDNS(node *models.Node) string {
  164. if !servercfg.GetManageDNS() {
  165. return ""
  166. }
  167. h, err := GetHost(node.HostID.String())
  168. if err != nil {
  169. return ""
  170. }
  171. if h.DNS != "yes" {
  172. return ""
  173. }
  174. dns := []string{}
  175. if node.Address.IP != nil {
  176. dns = append(dns, node.Address.IP.String())
  177. }
  178. if node.Address6.IP != nil {
  179. dns = append(dns, node.Address6.IP.String())
  180. }
  181. return strings.Join(dns, ",")
  182. }
  183. func SetDNSOnWgConfig(gwNode *models.Node, extclient *models.ExtClient) {
  184. if extclient.DNS == "" {
  185. extclient.DNS = GetGwDNS(gwNode)
  186. }
  187. }
  188. // GetCustomDNS - gets the custom DNS of a network
  189. func GetCustomDNS(network string) ([]models.DNSEntry, error) {
  190. var dns []models.DNSEntry
  191. collection, err := database.FetchRecords(database.DNS_TABLE_NAME)
  192. if err != nil {
  193. return dns, err
  194. }
  195. for _, value := range collection { // filter for entries based on network
  196. var entry models.DNSEntry
  197. if err := json.Unmarshal([]byte(value), &entry); err != nil {
  198. continue
  199. }
  200. if entry.Network == network {
  201. dns = append(dns, entry)
  202. }
  203. }
  204. return dns, err
  205. }
  206. // SetCorefile - sets the core file of the system
  207. func SetCorefile(domains string) error {
  208. dir, err := os.Getwd()
  209. if err != nil {
  210. return err
  211. }
  212. err = os.MkdirAll(dir+"/config/dnsconfig", 0744)
  213. if err != nil {
  214. logger.Log(0, "couldnt find or create /config/dnsconfig")
  215. return err
  216. }
  217. corefile := domains + ` {
  218. reload 15s
  219. hosts /root/dnsconfig/netmaker.hosts {
  220. fallthrough
  221. }
  222. forward . 8.8.8.8 8.8.4.4
  223. log
  224. }
  225. `
  226. err = os.WriteFile(dir+"/config/dnsconfig/Corefile", []byte(corefile), 0644)
  227. if err != nil {
  228. return err
  229. }
  230. return err
  231. }
  232. // GetAllDNS - gets all dns entries
  233. func GetAllDNS() ([]models.DNSEntry, error) {
  234. var dns []models.DNSEntry
  235. networks, err := GetNetworks()
  236. if err != nil && !database.IsEmptyRecord(err) {
  237. return []models.DNSEntry{}, err
  238. }
  239. for _, net := range networks {
  240. netdns, err := GetDNS(net.NetID)
  241. if err != nil {
  242. return []models.DNSEntry{}, nil
  243. }
  244. dns = append(dns, netdns...)
  245. }
  246. return dns, nil
  247. }
  248. // GetDNSEntryNum - gets which entry the dns was
  249. func GetDNSEntryNum(domain string, network string) (int, error) {
  250. num := 0
  251. entries, err := GetDNS(network)
  252. if err != nil {
  253. return 0, err
  254. }
  255. for i := 0; i < len(entries); i++ {
  256. if domain == entries[i].Name {
  257. num++
  258. }
  259. }
  260. return num, nil
  261. }
  262. // SortDNSEntrys - Sorts slice of DNSEnteys by their Address alphabetically with numbers first
  263. func SortDNSEntrys(unsortedDNSEntrys []models.DNSEntry) {
  264. sort.Slice(unsortedDNSEntrys, func(i, j int) bool {
  265. return unsortedDNSEntrys[i].Address < unsortedDNSEntrys[j].Address
  266. })
  267. }
  // dnsEntryNamePattern matches non-empty strings made only of ASCII letters,
// digits, '-' and '.'. It is compiled once at package init rather than on
// every validation call.
var dnsEntryNamePattern = regexp.MustCompile(`^[A-Za-z0-9-.]+$`)

// IsDNSEntryValid reports whether d is a non-empty DNS entry name consisting
// solely of ASCII letters (A-Z, a-z), digits (0-9), hyphens and dots.
func IsDNSEntryValid(d string) bool {
	return dnsEntryNamePattern.MatchString(d)
}
  273. // ValidateDNSCreate - checks if an entry is valid
  274. func ValidateDNSCreate(entry models.DNSEntry) error {
  275. if !IsDNSEntryValid(entry.Name) {
  276. return errors.New("invalid input. Only uppercase letters (A-Z), lowercase letters (a-z), numbers (0-9), minus sign (-) and dots (.) are allowed")
  277. }
  278. v := validator.New()
  279. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  280. match, _ := regexp.MatchString(`\s`, entry.Name)
  281. return !match
  282. })
  283. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  284. num, err := GetDNSEntryNum(entry.Name, entry.Network)
  285. return err == nil && num == 0
  286. })
  287. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  288. _, err := GetParentNetwork(entry.Network)
  289. return err == nil
  290. })
  291. err := v.Struct(entry)
  292. if err != nil {
  293. for _, e := range err.(validator.ValidationErrors) {
  294. logger.Log(1, e.Error())
  295. }
  296. }
  297. return err
  298. }
  299. // ValidateDNSUpdate - validates a DNS update
  300. func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
  301. v := validator.New()
  302. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  303. match, _ := regexp.MatchString(`\s`, entry.Name)
  304. return !match
  305. })
  306. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  307. //if name & net not changing name we are good
  308. if change.Name == entry.Name && change.Network == entry.Network {
  309. return true
  310. }
  311. num, err := GetDNSEntryNum(change.Name, change.Network)
  312. return err == nil && num == 0
  313. })
  314. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  315. _, err := GetParentNetwork(change.Network)
  316. return err == nil
  317. })
  318. err := v.Struct(change)
  319. if err != nil {
  320. for _, e := range err.(validator.ValidationErrors) {
  321. logger.Log(1, e.Error())
  322. }
  323. }
  324. return err
  325. }
  326. // DeleteDNS - deletes a DNS entry
  327. func DeleteDNS(domain string, network string) error {
  328. key, err := GetRecordKey(domain, network)
  329. if err != nil {
  330. return err
  331. }
  332. err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
  333. return err
  334. }
  335. // CreateDNS - creates a DNS entry
  336. func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
  337. k, err := GetRecordKey(entry.Name, entry.Network)
  338. if err != nil {
  339. return models.DNSEntry{}, err
  340. }
  341. data, err := json.Marshal(&entry)
  342. if err != nil {
  343. return models.DNSEntry{}, err
  344. }
  345. err = database.Insert(k, string(data), database.DNS_TABLE_NAME)
  346. return entry, err
  347. }
  348. func validateNameserverReq(ns schema.Nameserver) error {
  349. if ns.Name == "" {
  350. return errors.New("name is required")
  351. }
  352. if ns.NetworkID == "" {
  353. return errors.New("network is required")
  354. }
  355. if len(ns.Servers) == 0 {
  356. return errors.New("atleast one nameserver should be specified")
  357. }
  358. if !ns.MatchAll && len(ns.MatchDomains) == 0 {
  359. return errors.New("atleast one match domain is required")
  360. }
  361. if !ns.MatchAll {
  362. for _, matchDomain := range ns.MatchDomains {
  363. if !IsValidMatchDomain(matchDomain) {
  364. return errors.New("invalid match domain")
  365. }
  366. }
  367. }
  368. // check if valid broadcast peers are added
  369. if len(ns.Nodes) > 0 {
  370. for nodeID := range ns.Nodes {
  371. node, err := GetNodeByID(nodeID)
  372. if err != nil {
  373. return errors.New("invalid node")
  374. }
  375. if node.Network != ns.NetworkID {
  376. return errors.New("invalid network node")
  377. }
  378. }
  379. }
  380. return nil
  381. }
  382. func getNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
  383. ns := &schema.Nameserver{
  384. NetworkID: node.Network,
  385. }
  386. nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
  387. for _, nsI := range nsLi {
  388. if !nsI.Status {
  389. continue
  390. }
  391. _, all := nsI.Tags["*"]
  392. if all {
  393. for _, matchDomain := range nsI.MatchDomains {
  394. returnNsLi = append(returnNsLi, models.Nameserver{
  395. IPs: nsI.Servers,
  396. MatchDomain: matchDomain,
  397. })
  398. }
  399. continue
  400. }
  401. if _, ok := nsI.Nodes[node.ID.String()]; ok {
  402. for _, matchDomain := range nsI.MatchDomains {
  403. returnNsLi = append(returnNsLi, models.Nameserver{
  404. IPs: nsI.Servers,
  405. MatchDomain: matchDomain,
  406. })
  407. }
  408. }
  409. }
  410. if node.IsInternetGateway {
  411. globalNs := models.Nameserver{
  412. MatchDomain: ".",
  413. }
  414. for _, nsI := range GlobalNsList {
  415. globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
  416. }
  417. returnNsLi = append(returnNsLi, globalNs)
  418. }
  419. return
  420. }
  421. func getNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
  422. if h.DNS != "yes" {
  423. return
  424. }
  425. for _, nodeID := range h.Nodes {
  426. node, err := GetNodeByID(nodeID)
  427. if err != nil {
  428. continue
  429. }
  430. ns := &schema.Nameserver{
  431. NetworkID: node.Network,
  432. }
  433. nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
  434. for _, nsI := range nsLi {
  435. if !nsI.Status {
  436. continue
  437. }
  438. _, all := nsI.Tags["*"]
  439. if all {
  440. for _, matchDomain := range nsI.MatchDomains {
  441. returnNsLi = append(returnNsLi, models.Nameserver{
  442. IPs: nsI.Servers,
  443. MatchDomain: matchDomain,
  444. })
  445. }
  446. continue
  447. }
  448. if _, ok := nsI.Nodes[node.ID.String()]; ok {
  449. for _, matchDomain := range nsI.MatchDomains {
  450. returnNsLi = append(returnNsLi, models.Nameserver{
  451. IPs: nsI.Servers,
  452. MatchDomain: matchDomain,
  453. })
  454. }
  455. }
  456. }
  457. if node.IsInternetGateway {
  458. globalNs := models.Nameserver{
  459. MatchDomain: ".",
  460. }
  461. for _, nsI := range GlobalNsList {
  462. globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
  463. }
  464. returnNsLi = append(returnNsLi, globalNs)
  465. }
  466. }
  467. return
  468. }
  // matchDomainLabelRe matches one lowercase LDH label of 1–63 characters that
// neither starts nor ends with a hyphen. It is compiled once at package init
// rather than on every IsValidMatchDomain call.
var matchDomainLabelRe = regexp.MustCompile(`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$`)

// IsValidMatchDomain reports whether s is a valid "match domain".
// Rules (simple/ASCII), applied after trimming surrounding whitespace:
//   - "~." is allowed (match all).
//   - An optional leading "~" is allowed (e.g., "~example.com").
//   - A single optional trailing "." is allowed (FQDN form).
//   - No wildcards "*", no leading ".", no underscores.
//   - Labels are letters/digits/hyphen (LDH), 1–63 chars, with no leading or
//     trailing hyphen; matching is case-insensitive.
//   - Total length (without "~" and trailing dot) is at most 253.
func IsValidMatchDomain(s string) bool {
	s = strings.TrimSpace(s)
	if s == "" {
		return false
	}
	if s == "~." { // special case: match-all
		return true
	}
	// Strip optional leading "~".
	if strings.HasPrefix(s, "~") {
		s = s[1:]
		if s == "" {
			return false
		}
	}
	// Allow exactly one trailing dot; any further dot leaves an empty label
	// that the label loop below rejects.
	if strings.HasSuffix(s, ".") {
		s = s[:len(s)-1]
		if s == "" {
			return false
		}
	}
	// Disallow leading dot, wildcards, underscores.
	if strings.HasPrefix(s, ".") || strings.Contains(s, "*") || strings.Contains(s, "_") {
		return false
	}
	s = strings.ToLower(s)
	if len(s) > 253 {
		return false
	}
	for _, lbl := range strings.Split(s, ".") {
		if len(lbl) == 0 || len(lbl) > 63 {
			return false
		}
		if !matchDomainLabelRe.MatchString(lbl) {
			return false
		}
	}
	return true
}