router.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. //go:build e2e_testing
  2. // +build e2e_testing
  3. package router
  4. import (
  5. "fmt"
  6. "net"
  7. "reflect"
  8. "strconv"
  9. "sync"
  10. "github.com/slackhq/nebula"
  11. )
  12. type R struct {
  13. // Simple map of the ip:port registered on a control to the control
  14. // Basically a router, right?
  15. controls map[string]*nebula.Control
  16. // A map for inbound packets for a control that doesn't know about this address
  17. inNat map[string]*nebula.Control
  18. // A last used map, if an inbound packet hit the inNat map then
  19. // all return packets should use the same last used inbound address for the outbound sender
  20. // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
  21. outNat map[string]net.UDPAddr
  22. // All interactions are locked to help serialize behavior
  23. sync.Mutex
  24. }
  25. type ExitType int
  26. const (
  27. // Keeps routing, the function will get called again on the next packet
  28. KeepRouting ExitType = 0
  29. // Does not route this packet and exits immediately
  30. ExitNow ExitType = 1
  31. // Routes this packet and exits immediately afterwards
  32. RouteAndExit ExitType = 2
  33. )
  34. type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
  35. func NewR(controls ...*nebula.Control) *R {
  36. r := &R{
  37. controls: make(map[string]*nebula.Control),
  38. inNat: make(map[string]*nebula.Control),
  39. outNat: make(map[string]net.UDPAddr),
  40. }
  41. for _, c := range controls {
  42. addr := c.GetUDPAddr()
  43. if _, ok := r.controls[addr]; ok {
  44. panic("Duplicate listen address: " + addr)
  45. }
  46. r.controls[addr] = c
  47. }
  48. return r
  49. }
  50. // AddRoute will place the nebula controller at the ip and port specified.
  51. // It does not look at the addr attached to the instance.
  52. // If a route is used, this will behave like a NAT for the return path.
  53. // Rewriting the source ip:port to what was last sent to from the origin
  54. func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
  55. r.Lock()
  56. defer r.Unlock()
  57. inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
  58. if _, ok := r.inNat[inAddr]; ok {
  59. panic("Duplicate listen address inNat: " + inAddr)
  60. }
  61. r.inNat[inAddr] = c
  62. }
  63. // OnceFrom will route a single packet from sender then return
  64. // If the router doesn't have the nebula controller for that address, we panic
  65. func (r *R) OnceFrom(sender *nebula.Control) {
  66. r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
  67. return RouteAndExit
  68. })
  69. }
  70. // RouteUntilTxTun will route for sender and return when a packet is seen on receivers tun
  71. // If the router doesn't have the nebula controller for that address, we panic
  72. func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []byte {
  73. tunTx := receiver.GetTunTxChan()
  74. udpTx := sender.GetUDPTxChan()
  75. for {
  76. select {
  77. // Maybe we already have something on the tun for us
  78. case b := <-tunTx:
  79. return b
  80. // Nope, lets push the sender along
  81. case p := <-udpTx:
  82. outAddr := sender.GetUDPAddr()
  83. r.Lock()
  84. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  85. c := r.getControl(outAddr, inAddr, p)
  86. if c == nil {
  87. r.Unlock()
  88. panic("No control for udp tx")
  89. }
  90. c.InjectUDPPacket(p)
  91. r.Unlock()
  92. }
  93. }
  94. }
  95. // RouteExitFunc will call the whatDo func with each udp packet from sender.
  96. // whatDo can return:
  97. // - exitNow: the packet will not be routed and this call will return immediately
  98. // - routeAndExit: this call will return immediately after routing the last packet from sender
  99. // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
  100. func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
  101. h := &nebula.Header{}
  102. for {
  103. p := sender.GetFromUDP(true)
  104. r.Lock()
  105. if err := h.Parse(p.Data); err != nil {
  106. panic(err)
  107. }
  108. outAddr := sender.GetUDPAddr()
  109. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  110. receiver := r.getControl(outAddr, inAddr, p)
  111. if receiver == nil {
  112. r.Unlock()
  113. panic("Can't route for host: " + inAddr)
  114. }
  115. e := whatDo(p, receiver)
  116. switch e {
  117. case ExitNow:
  118. r.Unlock()
  119. return
  120. case RouteAndExit:
  121. receiver.InjectUDPPacket(p)
  122. r.Unlock()
  123. return
  124. case KeepRouting:
  125. receiver.InjectUDPPacket(p)
  126. default:
  127. panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
  128. }
  129. r.Unlock()
  130. }
  131. }
  132. // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
  133. // If the router doesn't have the nebula controller for that address, we panic
  134. func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
  135. h := &nebula.Header{}
  136. r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
  137. if err := h.Parse(p.Data); err != nil {
  138. panic(err)
  139. }
  140. if h.Type == msgType && h.Subtype == subType {
  141. return RouteAndExit
  142. }
  143. return KeepRouting
  144. })
  145. }
  146. // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
  147. // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
  148. // If the router doesn't have the nebula controller for that address, we panic
  149. func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
  150. if finish == KeepRouting {
  151. finish = RouteAndExit
  152. }
  153. r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
  154. if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
  155. return finish
  156. }
  157. return KeepRouting
  158. })
  159. }
  160. // RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from
  161. // whatDo can return:
  162. // - exitNow: the packet will not be routed and this call will return immediately
  163. // - routeAndExit: this call will return immediately after routing the last packet from sender
  164. // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
  165. func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
  166. sc := make([]reflect.SelectCase, len(r.controls))
  167. cm := make([]*nebula.Control, len(r.controls))
  168. i := 0
  169. for _, c := range r.controls {
  170. sc[i] = reflect.SelectCase{
  171. Dir: reflect.SelectRecv,
  172. Chan: reflect.ValueOf(c.GetUDPTxChan()),
  173. Send: reflect.Value{},
  174. }
  175. cm[i] = c
  176. i++
  177. }
  178. for {
  179. x, rx, _ := reflect.Select(sc)
  180. r.Lock()
  181. p := rx.Interface().(*nebula.UdpPacket)
  182. outAddr := cm[x].GetUDPAddr()
  183. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  184. receiver := r.getControl(outAddr, inAddr, p)
  185. if receiver == nil {
  186. r.Unlock()
  187. panic("Can't route for host: " + inAddr)
  188. }
  189. e := whatDo(p, receiver)
  190. switch e {
  191. case ExitNow:
  192. r.Unlock()
  193. return
  194. case RouteAndExit:
  195. receiver.InjectUDPPacket(p)
  196. r.Unlock()
  197. return
  198. case KeepRouting:
  199. receiver.InjectUDPPacket(p)
  200. default:
  201. panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
  202. }
  203. r.Unlock()
  204. }
  205. }
  206. // FlushAll will route for every registered controller, exiting once there are no packets left to route
  207. func (r *R) FlushAll() {
  208. sc := make([]reflect.SelectCase, len(r.controls))
  209. cm := make([]*nebula.Control, len(r.controls))
  210. i := 0
  211. for _, c := range r.controls {
  212. sc[i] = reflect.SelectCase{
  213. Dir: reflect.SelectRecv,
  214. Chan: reflect.ValueOf(c.GetUDPTxChan()),
  215. Send: reflect.Value{},
  216. }
  217. cm[i] = c
  218. i++
  219. }
  220. // Add a default case to exit when nothing is left to send
  221. sc = append(sc, reflect.SelectCase{
  222. Dir: reflect.SelectDefault,
  223. Chan: reflect.Value{},
  224. Send: reflect.Value{},
  225. })
  226. for {
  227. x, rx, ok := reflect.Select(sc)
  228. if !ok {
  229. return
  230. }
  231. r.Lock()
  232. p := rx.Interface().(*nebula.UdpPacket)
  233. outAddr := cm[x].GetUDPAddr()
  234. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  235. receiver := r.getControl(outAddr, inAddr, p)
  236. if receiver == nil {
  237. r.Unlock()
  238. panic("Can't route for host: " + inAddr)
  239. }
  240. r.Unlock()
  241. }
  242. }
  243. // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
  244. // This is an internal router function, the caller must hold the lock
  245. func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
  246. if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
  247. p.FromIp = newAddr.IP
  248. p.FromPort = uint16(newAddr.Port)
  249. }
  250. c, ok := r.inNat[toAddr]
  251. if ok {
  252. sHost, sPort, err := net.SplitHostPort(toAddr)
  253. if err != nil {
  254. panic(err)
  255. }
  256. port, err := strconv.Atoi(sPort)
  257. if err != nil {
  258. panic(err)
  259. }
  260. r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
  261. IP: net.ParseIP(sHost),
  262. Port: port,
  263. }
  264. return c
  265. }
  266. return r.controls[toAddr]
  267. }