main.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "net"
  6. "os"
  7. "os/signal"
  8. "strconv"
  9. "strings"
  10. "syscall"
  11. "time"
  12. "github.com/sirupsen/logrus"
  13. "github.com/slackhq/nebula/sshd"
  14. "gopkg.in/yaml.v2"
  15. )
  16. // The caller should provide a real logger, we have one just in case
  17. var l = logrus.New()
  18. type m map[string]interface{}
  19. type CommandRequest struct {
  20. Command string
  21. Callback chan error
  22. }
  23. func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
  24. l = logger
  25. l.Formatter = &logrus.TextFormatter{
  26. FullTimestamp: true,
  27. }
  28. // Print the config if in test, the exit comes later
  29. if configTest {
  30. b, err := yaml.Marshal(config.Settings)
  31. if err != nil {
  32. return err
  33. }
  34. // Print the final config
  35. l.Println(string(b))
  36. }
  37. err := configLogger(config)
  38. if err != nil {
  39. return NewContextualError("Failed to configure the logger", nil, err)
  40. }
  41. config.RegisterReloadCallback(func(c *Config) {
  42. err := configLogger(c)
  43. if err != nil {
  44. l.WithError(err).Error("Failed to configure the logger")
  45. }
  46. })
  47. // trustedCAs is currently a global, so loadCA operates on that global directly
  48. trustedCAs, err = loadCAFromConfig(config)
  49. if err != nil {
  50. //The errors coming out of loadCA are already nicely formatted
  51. return NewContextualError("Failed to load ca from config", nil, err)
  52. }
  53. l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
  54. cs, err := NewCertStateFromConfig(config)
  55. if err != nil {
  56. //The errors coming out of NewCertStateFromConfig are already nicely formatted
  57. return NewContextualError("Failed to load certificate from config", nil, err)
  58. }
  59. l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
  60. fw, err := NewFirewallFromConfig(cs.certificate, config)
  61. if err != nil {
  62. return NewContextualError("Error while loading firewall rules", nil, err)
  63. }
  64. l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
  65. // TODO: make sure mask is 4 bytes
  66. tunCidr := cs.certificate.Details.Ips[0]
  67. routes, err := parseRoutes(config, tunCidr)
  68. if err != nil {
  69. return NewContextualError("Could not parse tun.routes", nil, err)
  70. }
  71. unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
  72. if err != nil {
  73. return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
  74. }
  75. ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
  76. wireSSHReload(ssh, config)
  77. if config.GetBool("sshd.enabled", false) {
  78. err = configSSH(ssh, config)
  79. if err != nil {
  80. return NewContextualError("Error while configuring the sshd", nil, err)
  81. }
  82. }
  83. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  84. // All non system modifying configuration consumption should live above this line
  85. // tun config, listeners, anything modifying the computer should be below
  86. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  87. var tun *Tun
  88. if !configTest {
  89. config.CatchHUP()
  90. if tunFd != nil {
  91. tun, err = newTunFromFd(
  92. *tunFd,
  93. tunCidr,
  94. config.GetInt("tun.mtu", DEFAULT_MTU),
  95. routes,
  96. unsafeRoutes,
  97. config.GetInt("tun.tx_queue", 500),
  98. )
  99. } else {
  100. tun, err = newTun(
  101. config.GetString("tun.dev", ""),
  102. tunCidr,
  103. config.GetInt("tun.mtu", DEFAULT_MTU),
  104. routes,
  105. unsafeRoutes,
  106. config.GetInt("tun.tx_queue", 500),
  107. )
  108. }
  109. if err != nil {
  110. return NewContextualError("Failed to get a tun/tap device", nil, err)
  111. }
  112. }
  113. // set up our UDP listener
  114. udpQueues := config.GetInt("listen.routines", 1)
  115. var udpServer *udpConn
  116. if !configTest {
  117. udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
  118. if err != nil {
  119. return NewContextualError("Failed to open udp listener", nil, err)
  120. }
  121. udpServer.reloadConfig(config)
  122. }
  123. sigChan := make(chan os.Signal)
  124. killChan := make(chan CommandRequest)
  125. if commandChan != nil {
  126. go func() {
  127. cmd := CommandRequest{}
  128. for {
  129. cmd = <-commandChan
  130. switch cmd.Command {
  131. case "rebind":
  132. udpServer.Rebind()
  133. case "exit":
  134. killChan <- cmd
  135. }
  136. }
  137. }()
  138. }
  139. // Set up my internal host map
  140. var preferredRanges []*net.IPNet
  141. rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
  142. // First, check if 'preferred_ranges' is set and fallback to 'local_range'
  143. if len(rawPreferredRanges) > 0 {
  144. for _, rawPreferredRange := range rawPreferredRanges {
  145. _, preferredRange, err := net.ParseCIDR(rawPreferredRange)
  146. if err != nil {
  147. return NewContextualError("Failed to parse preferred ranges", nil, err)
  148. }
  149. preferredRanges = append(preferredRanges, preferredRange)
  150. }
  151. }
  152. // local_range was superseded by preferred_ranges. If it is still present,
  153. // merge the local_range setting into preferred_ranges. We will probably
  154. // deprecate local_range and remove in the future.
  155. rawLocalRange := config.GetString("local_range", "")
  156. if rawLocalRange != "" {
  157. _, localRange, err := net.ParseCIDR(rawLocalRange)
  158. if err != nil {
  159. return NewContextualError("Failed to parse local_range", nil, err)
  160. }
  161. // Check if the entry for local_range was already specified in
  162. // preferred_ranges. Don't put it into the slice twice if so.
  163. var found bool
  164. for _, r := range preferredRanges {
  165. if r.String() == localRange.String() {
  166. found = true
  167. break
  168. }
  169. }
  170. if !found {
  171. preferredRanges = append(preferredRanges, localRange)
  172. }
  173. }
  174. hostMap := NewHostMap("main", tunCidr, preferredRanges)
  175. hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
  176. hostMap.addUnsafeRoutes(&unsafeRoutes)
  177. hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
  178. l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
  179. /*
  180. config.SetDefault("promoter.interval", 10)
  181. go hostMap.Promoter(config.GetInt("promoter.interval"))
  182. */
  183. punchy := NewPunchyFromConfig(config)
  184. if punchy.Punch && !configTest {
  185. l.Info("UDP hole punching enabled")
  186. go hostMap.Punchy(udpServer)
  187. }
  188. port := config.GetInt("listen.port", 0)
  189. // If port is dynamic, discover it
  190. if port == 0 && !configTest {
  191. uPort, err := udpServer.LocalAddr()
  192. if err != nil {
  193. return NewContextualError("Failed to get listening port", nil, err)
  194. }
  195. port = int(uPort.Port)
  196. }
  197. amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
  198. // warn if am_lighthouse is enabled but upstream lighthouses exists
  199. rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
  200. if amLighthouse && len(rawLighthouseHosts) != 0 {
  201. l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
  202. }
  203. lighthouseHosts := make([]uint32, len(rawLighthouseHosts))
  204. for i, host := range rawLighthouseHosts {
  205. ip := net.ParseIP(host)
  206. if ip == nil {
  207. return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
  208. }
  209. if !tunCidr.Contains(ip) {
  210. return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
  211. }
  212. lighthouseHosts[i] = ip2int(ip)
  213. }
  214. lightHouse := NewLightHouse(
  215. amLighthouse,
  216. ip2int(tunCidr.IP),
  217. lighthouseHosts,
  218. //TODO: change to a duration
  219. config.GetInt("lighthouse.interval", 10),
  220. port,
  221. udpServer,
  222. punchy.Respond,
  223. punchy.Delay,
  224. config.GetBool("stats.lighthouse_metrics", false),
  225. )
  226. remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
  227. if err != nil {
  228. return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
  229. }
  230. lightHouse.SetRemoteAllowList(remoteAllowList)
  231. localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
  232. if err != nil {
  233. return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
  234. }
  235. lightHouse.SetLocalAllowList(localAllowList)
  236. //TODO: Move all of this inside functions in lighthouse.go
  237. for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
  238. vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
  239. if !tunCidr.Contains(vpnIp) {
  240. return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
  241. }
  242. vals, ok := v.([]interface{})
  243. if ok {
  244. for _, v := range vals {
  245. parts := strings.Split(fmt.Sprintf("%v", v), ":")
  246. addr, err := net.ResolveIPAddr("ip", parts[0])
  247. if err == nil {
  248. ip := addr.IP
  249. port, err := strconv.Atoi(parts[1])
  250. if err != nil {
  251. return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
  252. }
  253. lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
  254. }
  255. }
  256. } else {
  257. //TODO: make this all a helper
  258. parts := strings.Split(fmt.Sprintf("%v", v), ":")
  259. addr, err := net.ResolveIPAddr("ip", parts[0])
  260. if err == nil {
  261. ip := addr.IP
  262. port, err := strconv.Atoi(parts[1])
  263. if err != nil {
  264. return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
  265. }
  266. lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
  267. }
  268. }
  269. }
  270. err = lightHouse.ValidateLHStaticEntries()
  271. if err != nil {
  272. l.WithError(err).Error("Lighthouse unreachable")
  273. }
  274. var messageMetrics *MessageMetrics
  275. if config.GetBool("stats.message_metrics", false) {
  276. messageMetrics = newMessageMetrics()
  277. } else {
  278. messageMetrics = newMessageMetricsOnlyRecvError()
  279. }
  280. handshakeConfig := HandshakeConfig{
  281. tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
  282. retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
  283. waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
  284. messageMetrics: messageMetrics,
  285. }
  286. handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig)
  287. //TODO: These will be reused for psk
  288. //handshakeMACKey := config.GetString("handshake_mac.key", "")
  289. //handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
  290. serveDns := config.GetBool("lighthouse.serve_dns", false)
  291. checkInterval := config.GetInt("timers.connection_alive_interval", 5)
  292. pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
  293. ifConfig := &InterfaceConfig{
  294. HostMap: hostMap,
  295. Inside: tun,
  296. Outside: udpServer,
  297. certState: cs,
  298. Cipher: config.GetString("cipher", "aes"),
  299. Firewall: fw,
  300. ServeDns: serveDns,
  301. HandshakeManager: handshakeManager,
  302. lightHouse: lightHouse,
  303. checkInterval: checkInterval,
  304. pendingDeletionInterval: pendingDeletionInterval,
  305. DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
  306. DropMulticast: config.GetBool("tun.drop_multicast", false),
  307. UDPBatchSize: config.GetInt("listen.batch", 64),
  308. MessageMetrics: messageMetrics,
  309. }
  310. switch ifConfig.Cipher {
  311. case "aes":
  312. noiseEndianness = binary.BigEndian
  313. case "chachapoly":
  314. noiseEndianness = binary.LittleEndian
  315. default:
  316. return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
  317. }
  318. var ifce *Interface
  319. if !configTest {
  320. ifce, err = NewInterface(ifConfig)
  321. if err != nil {
  322. return fmt.Errorf("failed to initialize interface: %s", err)
  323. }
  324. ifce.RegisterConfigChangeCallbacks(config)
  325. go handshakeManager.Run(ifce)
  326. go lightHouse.LhUpdateWorker(ifce)
  327. }
  328. err = startStats(config, configTest)
  329. if err != nil {
  330. return NewContextualError("Failed to start stats emitter", nil, err)
  331. }
  332. if configTest {
  333. return nil
  334. }
  335. //TODO: check if we _should_ be emitting stats
  336. go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
  337. attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
  338. ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
  339. // Start DNS server last to allow using the nebula IP as lighthouse.dns.host
  340. if amLighthouse && serveDns {
  341. l.Debugln("Starting dns server")
  342. go dnsMain(hostMap, config)
  343. }
  344. if block {
  345. // Just sit here and be friendly, main thread.
  346. shutdownBlock(ifce, sigChan, killChan)
  347. } else {
  348. // Even though we aren't blocking we still want to shutdown gracefully
  349. go shutdownBlock(ifce, sigChan, killChan)
  350. }
  351. return nil
  352. }
  353. func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
  354. var cmd CommandRequest
  355. var sig string
  356. signal.Notify(sigChan, syscall.SIGTERM)
  357. signal.Notify(sigChan, syscall.SIGINT)
  358. select {
  359. case rawSig := <-sigChan:
  360. sig = rawSig.String()
  361. case cmd = <-killChan:
  362. sig = "controlling app"
  363. }
  364. l.WithField("signal", sig).Info("Caught signal, shutting down")
  365. //TODO: stop tun and udp routines, the lock on hostMap effectively does that though
  366. //TODO: this is probably better as a function in ConnectionManager or HostMap directly
  367. ifce.hostMap.Lock()
  368. for _, h := range ifce.hostMap.Hosts {
  369. if h.ConnectionState.ready {
  370. ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
  371. l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
  372. Debug("Sending close tunnel message")
  373. }
  374. }
  375. ifce.hostMap.Unlock()
  376. l.WithField("signal", sig).Info("Goodbye")
  377. select {
  378. case cmd.Callback <- nil:
  379. default:
  380. }
  381. }