dns.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. package logic
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "os"
  8. "regexp"
  9. "sort"
  10. "strings"
  11. validator "github.com/go-playground/validator/v10"
  12. "github.com/gravitl/netmaker/database"
  13. "github.com/gravitl/netmaker/db"
  14. "github.com/gravitl/netmaker/logger"
  15. "github.com/gravitl/netmaker/models"
  16. "github.com/gravitl/netmaker/schema"
  17. "github.com/txn2/txeh"
  18. )
  19. // SetDNS - sets the dns on file
  20. func SetDNS() error {
  21. hostfile, err := txeh.NewHosts(&txeh.HostsConfig{})
  22. if err != nil {
  23. return err
  24. }
  25. var corefilestring string
  26. networks, err := GetNetworks()
  27. if err != nil && !database.IsEmptyRecord(err) {
  28. return err
  29. }
  30. for _, net := range networks {
  31. corefilestring = corefilestring + net.NetID + " "
  32. dns, err := GetDNS(net.NetID)
  33. if err != nil && !database.IsEmptyRecord(err) {
  34. return err
  35. }
  36. for _, entry := range dns {
  37. hostfile.AddHost(entry.Address, entry.Name)
  38. }
  39. }
  40. dns := GetExtclientDNS()
  41. for _, entry := range dns {
  42. hostfile.AddHost(entry.Address, entry.Name)
  43. }
  44. if corefilestring == "" {
  45. corefilestring = "example.com"
  46. }
  47. err = hostfile.SaveAs("./config/dnsconfig/netmaker.hosts")
  48. if err != nil {
  49. return err
  50. }
  51. /* if something goes wrong with server DNS, check here
  52. // commented out bc we were not using IsSplitDNS
  53. if servercfg.IsSplitDNS() {
  54. err = SetCorefile(corefilestring)
  55. }
  56. */
  57. return err
  58. }
  59. // GetDNS - gets the DNS of a current network
  60. func GetDNS(network string) ([]models.DNSEntry, error) {
  61. dns, err := GetNodeDNS(network)
  62. if err != nil && !database.IsEmptyRecord(err) {
  63. return dns, err
  64. }
  65. customdns, err := GetCustomDNS(network)
  66. if err != nil && !database.IsEmptyRecord(err) {
  67. return dns, err
  68. }
  69. dns = append(dns, customdns...)
  70. return dns, nil
  71. }
  72. // GetExtclientDNS - gets all extclients dns entries
  73. func GetExtclientDNS() []models.DNSEntry {
  74. extclients, err := GetAllExtClients()
  75. if err != nil {
  76. return []models.DNSEntry{}
  77. }
  78. var dns []models.DNSEntry
  79. for _, extclient := range extclients {
  80. var entry = models.DNSEntry{}
  81. entry.Name = fmt.Sprintf("%s.%s", extclient.ClientID, extclient.Network)
  82. entry.Network = extclient.Network
  83. if extclient.Address != "" {
  84. entry.Address = extclient.Address
  85. }
  86. if extclient.Address6 != "" {
  87. entry.Address6 = extclient.Address6
  88. }
  89. dns = append(dns, entry)
  90. }
  91. return dns
  92. }
  93. // GetNodeDNS - gets the DNS of a network node
  94. func GetNodeDNS(network string) ([]models.DNSEntry, error) {
  95. var dns []models.DNSEntry
  96. nodes, err := GetNetworkNodes(network)
  97. if err != nil {
  98. return dns, err
  99. }
  100. defaultDomain := GetDefaultDomain()
  101. for _, node := range nodes {
  102. if node.Network != network {
  103. continue
  104. }
  105. host, err := GetHost(node.HostID.String())
  106. if err != nil {
  107. continue
  108. }
  109. var entry = models.DNSEntry{}
  110. if defaultDomain == "" {
  111. entry.Name = fmt.Sprintf("%s.%s", host.Name, network)
  112. } else {
  113. entry.Name = fmt.Sprintf("%s.%s.%s", host.Name, network, defaultDomain)
  114. }
  115. entry.Network = network
  116. if node.Address.IP != nil {
  117. entry.Address = node.Address.IP.String()
  118. }
  119. if node.Address6.IP != nil {
  120. entry.Address6 = node.Address6.IP.String()
  121. }
  122. dns = append(dns, entry)
  123. }
  124. return dns, nil
  125. }
  126. // GetCustomDNS - gets the custom DNS of a network
  127. func GetCustomDNS(network string) ([]models.DNSEntry, error) {
  128. var dns []models.DNSEntry
  129. collection, err := database.FetchRecords(database.DNS_TABLE_NAME)
  130. if err != nil {
  131. return dns, err
  132. }
  133. for _, value := range collection { // filter for entries based on network
  134. var entry models.DNSEntry
  135. if err := json.Unmarshal([]byte(value), &entry); err != nil {
  136. continue
  137. }
  138. if entry.Network == network {
  139. dns = append(dns, entry)
  140. }
  141. }
  142. return dns, err
  143. }
  144. // SetCorefile - sets the core file of the system
  145. func SetCorefile(domains string) error {
  146. dir, err := os.Getwd()
  147. if err != nil {
  148. return err
  149. }
  150. err = os.MkdirAll(dir+"/config/dnsconfig", 0744)
  151. if err != nil {
  152. logger.Log(0, "couldnt find or create /config/dnsconfig")
  153. return err
  154. }
  155. corefile := domains + ` {
  156. reload 15s
  157. hosts /root/dnsconfig/netmaker.hosts {
  158. fallthrough
  159. }
  160. forward . 8.8.8.8 8.8.4.4
  161. log
  162. }
  163. `
  164. err = os.WriteFile(dir+"/config/dnsconfig/Corefile", []byte(corefile), 0644)
  165. if err != nil {
  166. return err
  167. }
  168. return err
  169. }
  170. // GetAllDNS - gets all dns entries
  171. func GetAllDNS() ([]models.DNSEntry, error) {
  172. var dns []models.DNSEntry
  173. networks, err := GetNetworks()
  174. if err != nil && !database.IsEmptyRecord(err) {
  175. return []models.DNSEntry{}, err
  176. }
  177. for _, net := range networks {
  178. netdns, err := GetDNS(net.NetID)
  179. if err != nil {
  180. return []models.DNSEntry{}, nil
  181. }
  182. dns = append(dns, netdns...)
  183. }
  184. return dns, nil
  185. }
  186. // GetDNSEntryNum - gets which entry the dns was
  187. func GetDNSEntryNum(domain string, network string) (int, error) {
  188. num := 0
  189. entries, err := GetDNS(network)
  190. if err != nil {
  191. return 0, err
  192. }
  193. for i := 0; i < len(entries); i++ {
  194. if domain == entries[i].Name {
  195. num++
  196. }
  197. }
  198. return num, nil
  199. }
  200. // SortDNSEntrys - Sorts slice of DNSEnteys by their Address alphabetically with numbers first
  201. func SortDNSEntrys(unsortedDNSEntrys []models.DNSEntry) {
  202. sort.Slice(unsortedDNSEntrys, func(i, j int) bool {
  203. return unsortedDNSEntrys[i].Address < unsortedDNSEntrys[j].Address
  204. })
  205. }
  // dnsEntryNameRegex matches non-empty strings made only of ASCII letters,
// digits, '-' and '.'. It is compiled once instead of on every call.
var dnsEntryNameRegex = regexp.MustCompile(`^[A-Za-z0-9-.]+$`)

// IsDNSEntryValid reports whether d is a non-empty DNS entry name consisting
// only of uppercase letters (A-Z), lowercase letters (a-z), digits (0-9),
// hyphens (-) and dots (.).
func IsDNSEntryValid(d string) bool {
	return dnsEntryNameRegex.MatchString(d)
}
  211. // ValidateDNSCreate - checks if an entry is valid
  212. func ValidateDNSCreate(entry models.DNSEntry) error {
  213. if !IsDNSEntryValid(entry.Name) {
  214. return errors.New("invalid input. Only uppercase letters (A-Z), lowercase letters (a-z), numbers (0-9), minus sign (-) and dots (.) are allowed")
  215. }
  216. v := validator.New()
  217. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  218. match, _ := regexp.MatchString(`\s`, entry.Name)
  219. return !match
  220. })
  221. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  222. num, err := GetDNSEntryNum(entry.Name, entry.Network)
  223. return err == nil && num == 0
  224. })
  225. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  226. _, err := GetParentNetwork(entry.Network)
  227. return err == nil
  228. })
  229. err := v.Struct(entry)
  230. if err != nil {
  231. for _, e := range err.(validator.ValidationErrors) {
  232. logger.Log(1, e.Error())
  233. }
  234. }
  235. return err
  236. }
  237. // ValidateDNSUpdate - validates a DNS update
  238. func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
  239. v := validator.New()
  240. _ = v.RegisterValidation("whitespace", func(f1 validator.FieldLevel) bool {
  241. match, _ := regexp.MatchString(`\s`, entry.Name)
  242. return !match
  243. })
  244. _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
  245. //if name & net not changing name we are good
  246. if change.Name == entry.Name && change.Network == entry.Network {
  247. return true
  248. }
  249. num, err := GetDNSEntryNum(change.Name, change.Network)
  250. return err == nil && num == 0
  251. })
  252. _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
  253. _, err := GetParentNetwork(change.Network)
  254. return err == nil
  255. })
  256. err := v.Struct(change)
  257. if err != nil {
  258. for _, e := range err.(validator.ValidationErrors) {
  259. logger.Log(1, e.Error())
  260. }
  261. }
  262. return err
  263. }
  264. // DeleteDNS - deletes a DNS entry
  265. func DeleteDNS(domain string, network string) error {
  266. key, err := GetRecordKey(domain, network)
  267. if err != nil {
  268. return err
  269. }
  270. err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
  271. return err
  272. }
  273. // CreateDNS - creates a DNS entry
  274. func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
  275. k, err := GetRecordKey(entry.Name, entry.Network)
  276. if err != nil {
  277. return models.DNSEntry{}, err
  278. }
  279. data, err := json.Marshal(&entry)
  280. if err != nil {
  281. return models.DNSEntry{}, err
  282. }
  283. err = database.Insert(k, string(data), database.DNS_TABLE_NAME)
  284. return entry, err
  285. }
  286. func ValidateNameserverReq(ns schema.Nameserver) error {
  287. if ns.Name == "" {
  288. return errors.New("name is required")
  289. }
  290. if ns.NetworkID == "" {
  291. return errors.New("network is required")
  292. }
  293. if len(ns.Servers) == 0 {
  294. return errors.New("atleast one nameserver should be specified")
  295. }
  296. if !IsValidMatchDomain(ns.MatchDomain) {
  297. return errors.New("invalid match domain")
  298. }
  299. return nil
  300. }
  301. func ValidateUpdateNameserverReq(updateNs schema.Nameserver) error {
  302. if updateNs.Name == "" {
  303. return errors.New("name is required")
  304. }
  305. if len(updateNs.Servers) == 0 {
  306. return errors.New("atleast one nameserver should be specified")
  307. }
  308. if !IsValidMatchDomain(updateNs.MatchDomain) {
  309. return errors.New("invalid match domain")
  310. }
  311. return nil
  312. }
  313. func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
  314. if h.DNS != "yes" {
  315. return
  316. }
  317. for _, nodeID := range h.Nodes {
  318. node, err := GetNodeByID(nodeID)
  319. if err != nil {
  320. continue
  321. }
  322. ns := &schema.Nameserver{
  323. NetworkID: node.Network,
  324. }
  325. nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
  326. for _, nsI := range nsLi {
  327. if !nsI.Status {
  328. continue
  329. }
  330. _, all := nsI.Tags["*"]
  331. if all {
  332. returnNsLi = append(returnNsLi, models.Nameserver{
  333. IPs: nsI.Servers,
  334. MatchDomain: nsI.MatchDomain,
  335. })
  336. continue
  337. }
  338. for tagI := range node.Tags {
  339. if _, ok := nsI.Tags[tagI.String()]; ok {
  340. returnNsLi = append(returnNsLi, models.Nameserver{
  341. IPs: nsI.Servers,
  342. MatchDomain: nsI.MatchDomain,
  343. })
  344. }
  345. }
  346. }
  347. }
  348. return
  349. }
  // matchDomainLabelRegex matches one lowercase LDH label of 1–63 characters
// that neither starts nor ends with a hyphen. It is compiled once instead of
// on every IsValidMatchDomain call.
var matchDomainLabelRegex = regexp.MustCompile(`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$`)

// IsValidMatchDomain reports whether s is a valid "match domain".
// Surrounding whitespace is ignored and matching is case-insensitive.
// Rules (simple/ASCII):
//   - "~." is allowed (match all).
//   - An optional leading "~" is allowed (e.g. "~example.com").
//   - A single optional trailing "." is allowed (FQDN form).
//   - No wildcards "*", no leading ".", no underscores, no empty labels.
//   - Labels: letters/digits/hyphen (LDH), 1–63 chars, no leading/trailing hyphen.
//   - Total length (without the trailing dot) is at most 253.
func IsValidMatchDomain(s string) bool {
	s = strings.TrimSpace(s)
	if s == "" {
		return false
	}
	if s == "~." { // special case: match-all
		return true
	}
	// Strip an optional leading "~".
	if strings.HasPrefix(s, "~") {
		s = s[1:]
		if s == "" {
			return false
		}
	}
	// Allow exactly one trailing dot. Any further trailing dot leaves an
	// empty label, which the label check rejects.
	if strings.HasSuffix(s, ".") {
		s = s[:len(s)-1]
		if s == "" {
			return false
		}
	}
	if strings.HasPrefix(s, ".") || strings.Contains(s, "*") || strings.Contains(s, "_") {
		return false
	}
	s = strings.ToLower(s)
	if len(s) > 253 {
		return false
	}
	// The regex enforces 1–63 characters per label, so empty labels (e.g.
	// "a..b") and over-long labels are rejected here as well.
	for _, label := range strings.Split(s, ".") {
		if !matchDomainLabelRegex.MatchString(label) {
			return false
		}
	}
	return true
}