dns.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. package logic
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "os"
  9. "regexp"
  10. "sort"
  11. "strings"
  12. validator "github.com/go-playground/validator/v10"
  13. "github.com/gravitl/netmaker/database"
  14. "github.com/gravitl/netmaker/db"
  15. "github.com/gravitl/netmaker/logger"
  16. "github.com/gravitl/netmaker/models"
  17. "github.com/gravitl/netmaker/schema"
  18. "github.com/txn2/txeh"
  19. )
  // GlobalNs describes a public DNS provider: an identifier plus the list of
  // resolver IP addresses it offers.
  type GlobalNs struct {
  	// ID is the provider identifier (serialized as "id").
  	ID string `json:"id"`
  	// IPs holds the provider's resolver addresses as strings (serialized as "ips").
  	IPs []string `json:"ips"`
  }
  // GlobalNsList is the built-in table of public DNS providers, keyed by
  // provider name. Each value repeats the name as its ID and lists the
  // provider's IPv4 and IPv6 resolver addresses.
  var GlobalNsList = map[string]GlobalNs{
  	"Google": {
  		ID: "Google",
  		IPs: []string{
  			"8.8.8.8",
  			"8.8.4.4",
  			"2001:4860:4860::8888",
  			"2001:4860:4860::8844",
  		},
  	},
  	"Cloudflare": {
  		ID: "Cloudflare",
  		IPs: []string{
  			"1.1.1.1",
  			"1.0.0.1",
  			"2606:4700:4700::1111",
  			"2606:4700:4700::1001",
  		},
  	},
  	"Quad9": {
  		ID: "Quad9",
  		IPs: []string{
  			"9.9.9.9",
  			"149.112.112.112",
  			"2620:fe::fe",
  			"2620:fe::9",
  		},
  	},
  }
  53. // SetDNS - sets the dns on file
  54. func SetDNS() error {
  55. hostfile, err := txeh.NewHosts(&txeh.HostsConfig{})
  56. if err != nil {
  57. return err
  58. }
  59. var corefilestring string
  60. networks, err := GetNetworks()
  61. if err != nil && !database.IsEmptyRecord(err) {
  62. return err
  63. }
  64. for _, net := range networks {
  65. corefilestring = corefilestring + net.NetID + " "
  66. dns, err := GetDNS(net.NetID)
  67. if err != nil && !database.IsEmptyRecord(err) {
  68. return err
  69. }
  70. for _, entry := range dns {
  71. hostfile.AddHost(entry.Address, entry.Name)
  72. }
  73. }
  74. dns := GetExtclientDNS()
  75. for _, entry := range dns {
  76. hostfile.AddHost(entry.Address, entry.Name)
  77. }
  78. if corefilestring == "" {
  79. corefilestring = "example.com"
  80. }
  81. err = hostfile.SaveAs("./config/dnsconfig/netmaker.hosts")
  82. if err != nil {
  83. return err
  84. }
  85. /* if something goes wrong with server DNS, check here
  86. // commented out bc we were not using IsSplitDNS
  87. if servercfg.IsSplitDNS() {
  88. err = SetCorefile(corefilestring)
  89. }
  90. */
  91. return err
  92. }
  93. // GetDNS - gets the DNS of a current network
  94. func GetDNS(network string) ([]models.DNSEntry, error) {
  95. dns, err := GetNodeDNS(network)
  96. if err != nil && !database.IsEmptyRecord(err) {
  97. return dns, err
  98. }
  99. customdns, err := GetCustomDNS(network)
  100. if err != nil && !database.IsEmptyRecord(err) {
  101. return dns, err
  102. }
  103. dns = append(dns, customdns...)
  104. return dns, nil
  105. }
  106. func EgressDNs(network string) (entries []models.DNSEntry) {
  107. egs, _ := (&schema.Egress{
  108. Network: network,
  109. }).ListByNetwork(db.WithContext(context.TODO()))
  110. for _, egI := range egs {
  111. if egI.Domain != "" && len(egI.DomainAns) > 0 {
  112. entry := models.DNSEntry{
  113. Name: egI.Domain,
  114. }
  115. for _, domainAns := range egI.DomainAns {
  116. ip, _, err := net.ParseCIDR(domainAns)
  117. if err == nil {
  118. if ip.To4() != nil {
  119. entry.Address = ip.String()
  120. } else {
  121. entry.Address6 = ip.String()
  122. }
  123. }
  124. }
  125. entries = append(entries, entry)
  126. }
  127. }
  128. return
  129. }
  130. // GetExtclientDNS - gets all extclients dns entries
  131. func GetExtclientDNS() []models.DNSEntry {
  132. extclients, err := GetAllExtClients()
  133. if err != nil {
  134. return []models.DNSEntry{}
  135. }
  136. var dns []models.DNSEntry
  137. for _, extclient := range extclients {
  138. var entry = models.DNSEntry{}
  139. entry.Name = fmt.Sprintf("%s.%s", extclient.ClientID, extclient.Network)
  140. entry.Network = extclient.Network
  141. if extclient.Address != "" {
  142. entry.Address = extclient.Address
  143. }
  144. if extclient.Address6 != "" {
  145. entry.Address6 = extclient.Address6
  146. }
  147. dns = append(dns, entry)
  148. }
  149. return dns
  150. }
  151. // GetNodeDNS - gets the DNS of a network node
  152. func GetNodeDNS(network string) ([]models.DNSEntry, error) {
  153. var dns []models.DNSEntry
  154. nodes, err := GetNetworkNodes(network)
  155. if err != nil {
  156. return dns, err
  157. }
  158. defaultDomain := GetDefaultDomain()
  159. for _, node := range nodes {
  160. if node.Network != network {
  161. continue
  162. }
  163. host, err := GetHost(node.HostID.String())
  164. if err != nil {
  165. continue
  166. }
  167. var entry = models.DNSEntry{}
  168. if defaultDomain == "" {
  169. entry.Name = fmt.Sprintf("%s.%s", host.Name, network)
  170. } else {
  171. entry.Name = fmt.Sprintf("%s.%s.%s", host.Name, network, defaultDomain)
  172. }
  173. entry.Network = network
  174. if node.Address.IP != nil {
  175. entry.Address = node.Address.IP.String()
  176. }
  177. if node.Address6.IP != nil {
  178. entry.Address6 = node.Address6.IP.String()
  179. }
  180. dns = append(dns, entry)
  181. }
  182. return dns, nil
  183. }
  184. // GetCustomDNS - gets the custom DNS of a network
  185. func GetCustomDNS(network string) ([]models.DNSEntry, error) {
  186. var dns []models.DNSEntry
  187. collection, err := database.FetchRecords(database.DNS_TABLE_NAME)
  188. if err != nil {
  189. return dns, err
  190. }
  191. for _, value := range collection { // filter for entries based on network
  192. var entry models.DNSEntry
  193. if err := json.Unmarshal([]byte(value), &entry); err != nil {
  194. continue
  195. }
  196. if entry.Network == network {
  197. dns = append(dns, entry)
  198. }
  199. }
  200. return dns, err
  201. }
  202. // SetCorefile - sets the core file of the system
  203. func SetCorefile(domains string) error {
  204. dir, err := os.Getwd()
  205. if err != nil {
  206. return err
  207. }
  208. err = os.MkdirAll(dir+"/config/dnsconfig", 0744)
  209. if err != nil {
  210. logger.Log(0, "couldnt find or create /config/dnsconfig")
  211. return err
  212. }
  213. corefile := domains + ` {
  214. reload 15s
  215. hosts /root/dnsconfig/netmaker.hosts {
  216. fallthrough
  217. }
  218. forward . 8.8.8.8 8.8.4.4
  219. log
  220. }
  221. `
  222. err = os.WriteFile(dir+"/config/dnsconfig/Corefile", []byte(corefile), 0644)
  223. if err != nil {
  224. return err
  225. }
  226. return err
  227. }
  228. // GetAllDNS - gets all dns entries
  229. func GetAllDNS() ([]models.DNSEntry, error) {
  230. var dns []models.DNSEntry
  231. networks, err := GetNetworks()
  232. if err != nil && !database.IsEmptyRecord(err) {
  233. return []models.DNSEntry{}, err
  234. }
  235. for _, net := range networks {
  236. netdns, err := GetDNS(net.NetID)
  237. if err != nil {
  238. return []models.DNSEntry{}, nil
  239. }
  240. dns = append(dns, netdns...)
  241. }
  242. return dns, nil
  243. }
  244. // GetDNSEntryNum - gets which entry the dns was
  245. func GetDNSEntryNum(domain string, network string) (int, error) {
  246. num := 0
  247. entries, err := GetDNS(network)
  248. if err != nil {
  249. return 0, err
  250. }
  251. for i := 0; i < len(entries); i++ {
  252. if domain == entries[i].Name {
  253. num++
  254. }
  255. }
  256. return num, nil
  257. }
  258. // SortDNSEntrys - Sorts slice of DNSEnteys by their Address alphabetically with numbers first
  259. func SortDNSEntrys(unsortedDNSEntrys []models.DNSEntry) {
  260. sort.Slice(unsortedDNSEntrys, func(i, j int) bool {
  261. return unsortedDNSEntrys[i].Address < unsortedDNSEntrys[j].Address
  262. })
  263. }
  // dnsEntryNameRe matches names made only of ASCII letters, digits, hyphens
  // and dots. It is compiled once at package initialisation instead of on
  // every call.
  var dnsEntryNameRe = regexp.MustCompile(`^[A-Za-z0-9-.]+$`)

  // IsDNSEntryValid reports whether d is a non-empty string that consists
  // solely of ASCII letters (A-Z, a-z), digits (0-9), hyphens (-) and dots (.).
  // It checks the character set only; it does not check label structure.
  func IsDNSEntryValid(d string) bool {
  	return dnsEntryNameRe.MatchString(d)
  }
  269. // ValidateDNSCreate - checks if an entry is valid
  270. func ValidateDNSCreate(entry models.DNSEntry) error {
  271. if !IsDNSEntryValid(entry.Name) {
  272. return errors.New("invalid input. Only uppercase letters (A-Z), lowercase letters (a-z), numbers (0-9), minus sign (-) and dots (.) are allowed")
  273. }
  274. v := validator.New()
  275. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  276. match, _ := regexp.MatchString(`\s`, entry.Name)
  277. return !match
  278. })
  279. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  280. num, err := GetDNSEntryNum(entry.Name, entry.Network)
  281. return err == nil && num == 0
  282. })
  283. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  284. _, err := GetParentNetwork(entry.Network)
  285. return err == nil
  286. })
  287. err := v.Struct(entry)
  288. if err != nil {
  289. for _, e := range err.(validator.ValidationErrors) {
  290. logger.Log(1, e.Error())
  291. }
  292. }
  293. return err
  294. }
  295. // ValidateDNSUpdate - validates a DNS update
  296. func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
  297. v := validator.New()
  298. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  299. match, _ := regexp.MatchString(`\s`, entry.Name)
  300. return !match
  301. })
  302. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  303. //if name & net not changing name we are good
  304. if change.Name == entry.Name && change.Network == entry.Network {
  305. return true
  306. }
  307. num, err := GetDNSEntryNum(change.Name, change.Network)
  308. return err == nil && num == 0
  309. })
  310. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  311. _, err := GetParentNetwork(change.Network)
  312. return err == nil
  313. })
  314. err := v.Struct(change)
  315. if err != nil {
  316. for _, e := range err.(validator.ValidationErrors) {
  317. logger.Log(1, e.Error())
  318. }
  319. }
  320. return err
  321. }
  322. // DeleteDNS - deletes a DNS entry
  323. func DeleteDNS(domain string, network string) error {
  324. key, err := GetRecordKey(domain, network)
  325. if err != nil {
  326. return err
  327. }
  328. err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
  329. return err
  330. }
  331. // CreateDNS - creates a DNS entry
  332. func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
  333. k, err := GetRecordKey(entry.Name, entry.Network)
  334. if err != nil {
  335. return models.DNSEntry{}, err
  336. }
  337. data, err := json.Marshal(&entry)
  338. if err != nil {
  339. return models.DNSEntry{}, err
  340. }
  341. err = database.Insert(k, string(data), database.DNS_TABLE_NAME)
  342. return entry, err
  343. }
  344. func ValidateNameserverReq(ns schema.Nameserver) error {
  345. if ns.Name == "" {
  346. return errors.New("name is required")
  347. }
  348. if ns.NetworkID == "" {
  349. return errors.New("network is required")
  350. }
  351. if len(ns.Servers) == 0 {
  352. return errors.New("atleast one nameserver should be specified")
  353. }
  354. if !IsValidMatchDomain(ns.MatchDomain) {
  355. return errors.New("invalid match domain")
  356. }
  357. return nil
  358. }
  359. func ValidateUpdateNameserverReq(updateNs schema.Nameserver) error {
  360. if updateNs.Name == "" {
  361. return errors.New("name is required")
  362. }
  363. if len(updateNs.Servers) == 0 {
  364. return errors.New("atleast one nameserver should be specified")
  365. }
  366. if !IsValidMatchDomain(updateNs.MatchDomain) {
  367. return errors.New("invalid match domain")
  368. }
  369. return nil
  370. }
  371. func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
  372. if h.DNS != "yes" {
  373. return
  374. }
  375. for _, nodeID := range h.Nodes {
  376. node, err := GetNodeByID(nodeID)
  377. if err != nil {
  378. continue
  379. }
  380. ns := &schema.Nameserver{
  381. NetworkID: node.Network,
  382. }
  383. nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
  384. for _, nsI := range nsLi {
  385. if !nsI.Status {
  386. continue
  387. }
  388. _, all := nsI.Tags["*"]
  389. if all {
  390. returnNsLi = append(returnNsLi, models.Nameserver{
  391. IPs: nsI.Servers,
  392. MatchDomain: nsI.MatchDomain,
  393. })
  394. continue
  395. }
  396. for tagI := range node.Tags {
  397. if _, ok := nsI.Tags[tagI.String()]; ok {
  398. returnNsLi = append(returnNsLi, models.Nameserver{
  399. IPs: nsI.Servers,
  400. MatchDomain: nsI.MatchDomain,
  401. })
  402. }
  403. }
  404. }
  405. if node.IsInternetGateway {
  406. globalNs := models.Nameserver{
  407. MatchDomain: ".",
  408. }
  409. for _, nsI := range GlobalNsList {
  410. globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
  411. }
  412. returnNsLi = append(returnNsLi, globalNs)
  413. }
  414. // add egress domains to list
  415. e := schema.Egress{
  416. Network: node.Network,
  417. }
  418. egs, _ := e.ListByNetwork(db.WithContext(context.TODO()))
  419. for _, egI := range egs {
  420. if egI.Domain != "" && len(egI.DomainAns) > 0 {
  421. egressNs := models.Nameserver{
  422. MatchDomain: egI.Domain,
  423. }
  424. if node.Address.IP != nil {
  425. egressNs.IPs = append(egressNs.IPs, node.Address.IP.String())
  426. }
  427. if node.Address6.IP != nil {
  428. egressNs.IPs = append(egressNs.IPs, node.Address6.IP.String())
  429. }
  430. returnNsLi = append(returnNsLi, egressNs)
  431. }
  432. }
  433. }
  434. return
  435. }
  // matchDomainLabelRe matches one DNS label: ASCII letters, digits and
  // hyphens (LDH), 1-63 characters, with no leading or trailing hyphen. Both
  // letter cases are listed explicitly rather than using (?i), because
  // Unicode case folding would let non-ASCII runes such as U+212A (KELVIN
  // SIGN) match "k".
  var matchDomainLabelRe = regexp.MustCompile(`^[A-Za-z0-9](?:[A-Za-z0-9-]{0,61}[A-Za-z0-9])?$`)

  // IsValidMatchDomain reports whether s is a valid "match domain".
  // Surrounding whitespace is ignored.
  // Rules (simple/ASCII):
  //   - "~." is allowed (match all).
  //   - Optional leading "~" allowed (e.g., "~example.com").
  //   - Optional single trailing "." allowed (FQDN form).
  //   - No wildcards "*", no leading ".", no underscores, no empty labels.
  //   - Labels: ASCII letters/digits/hyphen (LDH), 1–63 chars, no
  //     leading/trailing hyphen. Letters may be upper or lower case.
  //   - Total length (without "~" and the trailing dot) ≤ 253.
  //
  // Non-ASCII input is always rejected. The input is deliberately not
  // lower-cased first: strings.ToLower maps some non-ASCII runes (for example
  // U+212A and U+0130) to ASCII letters, which would let them pass.
  func IsValidMatchDomain(s string) bool {
  	s = strings.TrimSpace(s)
  	if s == "" {
  		return false
  	}
  	if s == "~." { // special case: match-all
  		return true
  	}

  	// Strip the optional leading "~" and at most one trailing dot.
  	s = strings.TrimPrefix(s, "~")
  	s = strings.TrimSuffix(s, ".")
  	if s == "" || len(s) > 253 {
  		return false
  	}

  	// Each label must match on its own. This also rejects a leading dot and
  	// consecutive dots (empty label), wildcards and underscores.
  	for _, lbl := range strings.Split(s, ".") {
  		if !matchDomainLabelRe.MatchString(lbl) {
  			return false
  		}
  	}
  	return true
  }