interface.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. package nebula
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/netip"
  8. "os"
  9. "runtime"
  10. "sync/atomic"
  11. "time"
  12. "github.com/gaissmai/bart"
  13. "github.com/rcrowley/go-metrics"
  14. "github.com/sirupsen/logrus"
  15. "github.com/slackhq/nebula/config"
  16. "github.com/slackhq/nebula/firewall"
  17. "github.com/slackhq/nebula/header"
  18. "github.com/slackhq/nebula/overlay"
  19. "github.com/slackhq/nebula/udp"
  20. )
// mtu is the size, in bytes, of the packet buffers allocated by the tun reader
// loops in this file (listenIn and listenInBatch).
const mtu = 9001
  22. type InterfaceConfig struct {
  23. HostMap *HostMap
  24. Outside udp.Conn
  25. Inside overlay.Device
  26. pki *PKI
  27. Cipher string
  28. Firewall *Firewall
  29. ServeDns bool
  30. HandshakeManager *HandshakeManager
  31. lightHouse *LightHouse
  32. connectionManager *connectionManager
  33. DropLocalBroadcast bool
  34. DropMulticast bool
  35. routines int
  36. MessageMetrics *MessageMetrics
  37. version string
  38. relayManager *relayManager
  39. punchy *Punchy
  40. tryPromoteEvery uint32
  41. reQueryEvery uint32
  42. reQueryWait time.Duration
  43. ConntrackCacheTimeout time.Duration
  44. batchSize int
  45. l *logrus.Logger
  46. }
  47. type Interface struct {
  48. hostMap *HostMap
  49. outside udp.Conn
  50. inside overlay.Device
  51. pki *PKI
  52. firewall *Firewall
  53. connectionManager *connectionManager
  54. handshakeManager *HandshakeManager
  55. serveDns bool
  56. createTime time.Time
  57. lightHouse *LightHouse
  58. myBroadcastAddrsTable *bart.Lite
  59. myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
  60. myVpnAddrsTable *bart.Lite
  61. myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
  62. myVpnNetworksTable *bart.Lite
  63. dropLocalBroadcast bool
  64. dropMulticast bool
  65. routines int
  66. disconnectInvalid atomic.Bool
  67. closed atomic.Bool
  68. relayManager *relayManager
  69. tryPromoteEvery atomic.Uint32
  70. reQueryEvery atomic.Uint32
  71. reQueryWait atomic.Int64
  72. sendRecvErrorConfig sendRecvErrorConfig
  73. // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
  74. rebindCount int8
  75. version string
  76. conntrackCacheTimeout time.Duration
  77. batchSize int
  78. writers []udp.Conn
  79. readers []io.ReadWriteCloser
  80. metricHandshakes metrics.Histogram
  81. messageMetrics *MessageMetrics
  82. cachedPacketMetrics *cachedPacketMetrics
  83. l *logrus.Logger
  84. }
  85. type EncWriter interface {
  86. SendVia(via *HostInfo,
  87. relay *Relay,
  88. ad,
  89. nb,
  90. out []byte,
  91. nocopy bool,
  92. )
  93. SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
  94. SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
  95. Handshake(vpnAddr netip.Addr)
  96. GetHostInfo(vpnAddr netip.Addr) *HostInfo
  97. GetCertState() *CertState
  98. }
  99. // BatchReader is an interface for readers that support vectorized packet reading
  100. type BatchReader interface {
  101. BatchRead(buffers [][]byte, sizes []int) (int, error)
  102. }
  103. // BatchWriter is an interface for writers that support vectorized packet writing
  104. type BatchWriter interface {
  105. BatchWrite([][]byte) (int, error)
  106. }
  107. type sendRecvErrorConfig uint8
  108. const (
  109. sendRecvErrorAlways sendRecvErrorConfig = iota
  110. sendRecvErrorNever
  111. sendRecvErrorPrivate
  112. )
  113. func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
  114. switch s {
  115. case sendRecvErrorPrivate:
  116. return endpoint.Addr().IsPrivate()
  117. case sendRecvErrorAlways:
  118. return true
  119. case sendRecvErrorNever:
  120. return false
  121. default:
  122. panic(fmt.Errorf("invalid sendRecvErrorConfig value: %d", s))
  123. }
  124. }
  125. func (s sendRecvErrorConfig) String() string {
  126. switch s {
  127. case sendRecvErrorAlways:
  128. return "always"
  129. case sendRecvErrorNever:
  130. return "never"
  131. case sendRecvErrorPrivate:
  132. return "private"
  133. default:
  134. return fmt.Sprintf("invalid(%d)", s)
  135. }
  136. }
  137. func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
  138. if c.Outside == nil {
  139. return nil, errors.New("no outside connection")
  140. }
  141. if c.Inside == nil {
  142. return nil, errors.New("no inside interface (tun)")
  143. }
  144. if c.pki == nil {
  145. return nil, errors.New("no certificate state")
  146. }
  147. if c.Firewall == nil {
  148. return nil, errors.New("no firewall rules")
  149. }
  150. if c.connectionManager == nil {
  151. return nil, errors.New("no connection manager")
  152. }
  153. cs := c.pki.getCertState()
  154. ifce := &Interface{
  155. pki: c.pki,
  156. hostMap: c.HostMap,
  157. outside: c.Outside,
  158. inside: c.Inside,
  159. firewall: c.Firewall,
  160. serveDns: c.ServeDns,
  161. handshakeManager: c.HandshakeManager,
  162. createTime: time.Now(),
  163. lightHouse: c.lightHouse,
  164. dropLocalBroadcast: c.DropLocalBroadcast,
  165. dropMulticast: c.DropMulticast,
  166. routines: c.routines,
  167. version: c.version,
  168. writers: make([]udp.Conn, c.routines),
  169. readers: make([]io.ReadWriteCloser, c.routines),
  170. myVpnNetworks: cs.myVpnNetworks,
  171. myVpnNetworksTable: cs.myVpnNetworksTable,
  172. myVpnAddrs: cs.myVpnAddrs,
  173. myVpnAddrsTable: cs.myVpnAddrsTable,
  174. myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
  175. relayManager: c.relayManager,
  176. connectionManager: c.connectionManager,
  177. conntrackCacheTimeout: c.ConntrackCacheTimeout,
  178. batchSize: c.batchSize,
  179. metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
  180. messageMetrics: c.MessageMetrics,
  181. cachedPacketMetrics: &cachedPacketMetrics{
  182. sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
  183. dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
  184. },
  185. l: c.l,
  186. }
  187. ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
  188. ifce.reQueryEvery.Store(c.reQueryEvery)
  189. ifce.reQueryWait.Store(int64(c.reQueryWait))
  190. ifce.connectionManager.intf = ifce
  191. return ifce, nil
  192. }
  193. // activate creates the interface on the host. After the interface is created, any
  194. // other services that want to bind listeners to its IP may do so successfully. However,
  195. // the interface isn't going to process anything until run() is called.
  196. func (f *Interface) activate() {
  197. // actually turn on tun dev
  198. addr, err := f.outside.LocalAddr()
  199. if err != nil {
  200. f.l.WithError(err).Error("Failed to get udp listen address")
  201. }
  202. f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
  203. WithField("build", f.version).WithField("udpAddr", addr).
  204. WithField("boringcrypto", boringEnabled()).
  205. Info("Nebula interface is active")
  206. metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
  207. // Prepare n tun queues
  208. var reader io.ReadWriteCloser = f.inside
  209. for i := 0; i < f.routines; i++ {
  210. if i > 0 {
  211. reader, err = f.inside.NewMultiQueueReader()
  212. if err != nil {
  213. f.l.Fatal(err)
  214. }
  215. }
  216. f.readers[i] = reader
  217. }
  218. if err := f.inside.Activate(); err != nil {
  219. f.inside.Close()
  220. f.l.Fatal(err)
  221. }
  222. }
  223. func (f *Interface) run() {
  224. // Launch n queues to read packets from udp
  225. for i := 0; i < f.routines; i++ {
  226. go f.listenOut(i)
  227. }
  228. // Launch n queues to read packets from tun dev
  229. for i := 0; i < f.routines; i++ {
  230. go f.listenIn(f.readers[i], i)
  231. }
  232. }
  233. func (f *Interface) listenOut(i int) {
  234. runtime.LockOSThread()
  235. var li udp.Conn
  236. if i > 0 {
  237. li = f.writers[i]
  238. } else {
  239. li = f.outside
  240. }
  241. ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
  242. lhh := f.lightHouse.NewRequestHandler()
  243. plaintext := make([]byte, udp.MTU)
  244. h := &header.H{}
  245. fwPacket := &firewall.Packet{}
  246. nb := make([]byte, 12)
  247. li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
  248. f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
  249. })
  250. }
  251. func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
  252. runtime.LockOSThread()
  253. // Check if reader supports batch operations
  254. if batchReader, ok := reader.(BatchReader); ok {
  255. err := f.listenInBatch(batchReader, i)
  256. if err != nil {
  257. f.l.WithError(err).Error("Fatal error in batch packet reader, exiting goroutine")
  258. }
  259. return
  260. }
  261. // Fall back to single-packet mode
  262. packet := make([]byte, mtu)
  263. out := make([]byte, mtu)
  264. fwPacket := &firewall.Packet{}
  265. nb := make([]byte, 12, 12)
  266. conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
  267. for {
  268. n, err := reader.Read(packet)
  269. if err != nil {
  270. if errors.Is(err, os.ErrClosed) && f.closed.Load() {
  271. return
  272. }
  273. f.l.WithError(err).Error("Fatal error while reading outbound packet, exiting goroutine")
  274. return
  275. }
  276. f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
  277. }
  278. }
  279. // listenInBatch handles vectorized packet reading for improved performance
  280. func (f *Interface) listenInBatch(reader BatchReader, i int) error {
  281. // Allocate per-packet state and buffers for batch reading
  282. batchSize := f.batchSize
  283. if batchSize <= 0 {
  284. batchSize = 64 // Fallback to default if not configured
  285. }
  286. fwPackets := make([]*firewall.Packet, batchSize)
  287. outBuffers := make([][]byte, batchSize)
  288. nbBuffers := make([][]byte, batchSize)
  289. packets := make([][]byte, batchSize)
  290. sizes := make([]int, batchSize)
  291. for j := 0; j < batchSize; j++ {
  292. fwPackets[j] = &firewall.Packet{}
  293. outBuffers[j] = make([]byte, mtu)
  294. nbBuffers[j] = make([]byte, 12)
  295. packets[j] = make([]byte, mtu)
  296. }
  297. conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
  298. for {
  299. n, err := reader.BatchRead(packets, sizes)
  300. if err != nil {
  301. if errors.Is(err, os.ErrClosed) && f.closed.Load() {
  302. return nil
  303. }
  304. return fmt.Errorf("error while batch reading outbound packets: %w", err)
  305. }
  306. // Process each packet in the batch
  307. cache := conntrackCache.Get(f.l)
  308. for idx := 0; idx < n; idx++ {
  309. if sizes[idx] > 0 {
  310. // Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
  311. stateIdx := idx % len(fwPackets)
  312. f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)
  313. }
  314. }
  315. }
  316. }
  317. // writeTunBatch attempts to write multiple packets to the TUN device using batch operations if supported
  318. func (f *Interface) writeTunBatch(q int, packets [][]byte) error {
  319. if len(packets) == 0 {
  320. return nil
  321. }
  322. // Check if the reader/writer supports batch operations
  323. if batchWriter, ok := f.readers[q].(BatchWriter); ok {
  324. _, err := batchWriter.BatchWrite(packets)
  325. return err
  326. }
  327. // Fall back to writing packets individually
  328. for _, packet := range packets {
  329. if _, err := f.readers[q].Write(packet); err != nil {
  330. return err
  331. }
  332. }
  333. return nil
  334. }
  335. // writeTun writes a single packet to the TUN device
  336. func (f *Interface) writeTun(q int, packet []byte) error {
  337. _, err := f.readers[q].Write(packet)
  338. return err
  339. }
  340. func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
  341. c.RegisterReloadCallback(f.reloadFirewall)
  342. c.RegisterReloadCallback(f.reloadSendRecvError)
  343. c.RegisterReloadCallback(f.reloadDisconnectInvalid)
  344. c.RegisterReloadCallback(f.reloadMisc)
  345. for _, udpConn := range f.writers {
  346. c.RegisterReloadCallback(udpConn.ReloadConfig)
  347. }
  348. }
  349. func (f *Interface) reloadDisconnectInvalid(c *config.C) {
  350. initial := c.InitialLoad()
  351. if initial || c.HasChanged("pki.disconnect_invalid") {
  352. f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
  353. if !initial {
  354. f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
  355. }
  356. }
  357. }
  358. func (f *Interface) reloadFirewall(c *config.C) {
  359. //TODO: need to trigger/detect if the certificate changed too
  360. if c.HasChanged("firewall") == false {
  361. f.l.Debug("No firewall config change detected")
  362. return
  363. }
  364. fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
  365. if err != nil {
  366. f.l.WithError(err).Error("Error while creating firewall during reload")
  367. return
  368. }
  369. oldFw := f.firewall
  370. conntrack := oldFw.Conntrack
  371. conntrack.Lock()
  372. defer conntrack.Unlock()
  373. fw.rulesVersion = oldFw.rulesVersion + 1
  374. // If rulesVersion is back to zero, we have wrapped all the way around. Be
  375. // safe and just reset conntrack in this case.
  376. if fw.rulesVersion == 0 {
  377. f.l.WithField("firewallHashes", fw.GetRuleHashes()).
  378. WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
  379. WithField("rulesVersion", fw.rulesVersion).
  380. Warn("firewall rulesVersion has overflowed, resetting conntrack")
  381. } else {
  382. fw.Conntrack = conntrack
  383. }
  384. f.firewall = fw
  385. oldFw.Destroy()
  386. f.l.WithField("firewallHashes", fw.GetRuleHashes()).
  387. WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
  388. WithField("rulesVersion", fw.rulesVersion).
  389. Info("New firewall has been installed")
  390. }
  391. func (f *Interface) reloadSendRecvError(c *config.C) {
  392. if c.InitialLoad() || c.HasChanged("listen.send_recv_error") {
  393. stringValue := c.GetString("listen.send_recv_error", "always")
  394. switch stringValue {
  395. case "always":
  396. f.sendRecvErrorConfig = sendRecvErrorAlways
  397. case "never":
  398. f.sendRecvErrorConfig = sendRecvErrorNever
  399. case "private":
  400. f.sendRecvErrorConfig = sendRecvErrorPrivate
  401. default:
  402. if c.GetBool("listen.send_recv_error", true) {
  403. f.sendRecvErrorConfig = sendRecvErrorAlways
  404. } else {
  405. f.sendRecvErrorConfig = sendRecvErrorNever
  406. }
  407. }
  408. f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
  409. Info("Loaded send_recv_error config")
  410. }
  411. }
  412. func (f *Interface) reloadMisc(c *config.C) {
  413. if c.HasChanged("counters.try_promote") {
  414. n := c.GetUint32("counters.try_promote", defaultPromoteEvery)
  415. f.tryPromoteEvery.Store(n)
  416. f.l.Info("counters.try_promote has changed")
  417. }
  418. if c.HasChanged("counters.requery_every_packets") {
  419. n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery)
  420. f.reQueryEvery.Store(n)
  421. f.l.Info("counters.requery_every_packets has changed")
  422. }
  423. if c.HasChanged("timers.requery_wait_duration") {
  424. n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait)
  425. f.reQueryWait.Store(int64(n))
  426. f.l.Info("timers.requery_wait_duration has changed")
  427. }
  428. }
  429. func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
  430. ticker := time.NewTicker(i)
  431. defer ticker.Stop()
  432. udpStats := udp.NewUDPStatsEmitter(f.writers)
  433. certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
  434. certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
  435. certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
  436. for {
  437. select {
  438. case <-ctx.Done():
  439. return
  440. case <-ticker.C:
  441. f.firewall.EmitStats()
  442. f.handshakeManager.EmitStats()
  443. udpStats()
  444. certState := f.pki.getCertState()
  445. defaultCrt := certState.GetDefaultCertificate()
  446. certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
  447. certInitiatingVersion.Update(int64(defaultCrt.Version()))
  448. // Report the max certificate version we are capable of using
  449. if certState.v2Cert != nil {
  450. certMaxVersion.Update(int64(certState.v2Cert.Version()))
  451. } else {
  452. certMaxVersion.Update(int64(certState.v1Cert.Version()))
  453. }
  454. }
  455. }
  456. }
  457. func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
  458. return f.hostMap.QueryVpnAddr(vpnIp)
  459. }
  460. func (f *Interface) GetCertState() *CertState {
  461. return f.pki.getCertState()
  462. }
  463. func (f *Interface) Close() error {
  464. f.closed.Store(true)
  465. for _, u := range f.writers {
  466. err := u.Close()
  467. if err != nil {
  468. f.l.WithError(err).Error("Error while closing udp socket")
  469. }
  470. }
  471. // Release the tun device
  472. return f.inside.Close()
  473. }