2
0

router.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. //go:build e2e_testing
  2. // +build e2e_testing
  3. package router
  4. import (
  5. "fmt"
  6. "net"
  7. "reflect"
  8. "strconv"
  9. "sync"
  10. "github.com/slackhq/nebula"
  11. "github.com/slackhq/nebula/header"
  12. "github.com/slackhq/nebula/udp"
  13. )
  14. type R struct {
  15. // Simple map of the ip:port registered on a control to the control
  16. // Basically a router, right?
  17. controls map[string]*nebula.Control
  18. // A map for inbound packets for a control that doesn't know about this address
  19. inNat map[string]*nebula.Control
  20. // A last used map, if an inbound packet hit the inNat map then
  21. // all return packets should use the same last used inbound address for the outbound sender
  22. // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
  23. outNat map[string]net.UDPAddr
  24. // All interactions are locked to help serialize behavior
  25. sync.Mutex
  26. }
  27. type ExitType int
  28. const (
  29. // Keeps routing, the function will get called again on the next packet
  30. KeepRouting ExitType = 0
  31. // Does not route this packet and exits immediately
  32. ExitNow ExitType = 1
  33. // Routes this packet and exits immediately afterwards
  34. RouteAndExit ExitType = 2
  35. )
  36. type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
  37. func NewR(controls ...*nebula.Control) *R {
  38. r := &R{
  39. controls: make(map[string]*nebula.Control),
  40. inNat: make(map[string]*nebula.Control),
  41. outNat: make(map[string]net.UDPAddr),
  42. }
  43. for _, c := range controls {
  44. addr := c.GetUDPAddr()
  45. if _, ok := r.controls[addr]; ok {
  46. panic("Duplicate listen address: " + addr)
  47. }
  48. r.controls[addr] = c
  49. }
  50. return r
  51. }
  52. // AddRoute will place the nebula controller at the ip and port specified.
  53. // It does not look at the addr attached to the instance.
  54. // If a route is used, this will behave like a NAT for the return path.
  55. // Rewriting the source ip:port to what was last sent to from the origin
  56. func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
  57. r.Lock()
  58. defer r.Unlock()
  59. inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
  60. if _, ok := r.inNat[inAddr]; ok {
  61. panic("Duplicate listen address inNat: " + inAddr)
  62. }
  63. r.inNat[inAddr] = c
  64. }
  65. // OnceFrom will route a single packet from sender then return
  66. // If the router doesn't have the nebula controller for that address, we panic
  67. func (r *R) OnceFrom(sender *nebula.Control) {
  68. r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType {
  69. return RouteAndExit
  70. })
  71. }
  72. // RouteUntilTxTun will route for sender and return when a packet is seen on receivers tun
  73. // If the router doesn't have the nebula controller for that address, we panic
  74. func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []byte {
  75. tunTx := receiver.GetTunTxChan()
  76. udpTx := sender.GetUDPTxChan()
  77. for {
  78. select {
  79. // Maybe we already have something on the tun for us
  80. case b := <-tunTx:
  81. return b
  82. // Nope, lets push the sender along
  83. case p := <-udpTx:
  84. outAddr := sender.GetUDPAddr()
  85. r.Lock()
  86. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  87. c := r.getControl(outAddr, inAddr, p)
  88. if c == nil {
  89. r.Unlock()
  90. panic("No control for udp tx")
  91. }
  92. c.InjectUDPPacket(p)
  93. r.Unlock()
  94. }
  95. }
  96. }
  97. // RouteExitFunc will call the whatDo func with each udp packet from sender.
  98. // whatDo can return:
  99. // - exitNow: the packet will not be routed and this call will return immediately
  100. // - routeAndExit: this call will return immediately after routing the last packet from sender
  101. // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
  102. func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
  103. h := &header.H{}
  104. for {
  105. p := sender.GetFromUDP(true)
  106. r.Lock()
  107. if err := h.Parse(p.Data); err != nil {
  108. panic(err)
  109. }
  110. outAddr := sender.GetUDPAddr()
  111. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  112. receiver := r.getControl(outAddr, inAddr, p)
  113. if receiver == nil {
  114. r.Unlock()
  115. panic("Can't route for host: " + inAddr)
  116. }
  117. e := whatDo(p, receiver)
  118. switch e {
  119. case ExitNow:
  120. r.Unlock()
  121. return
  122. case RouteAndExit:
  123. receiver.InjectUDPPacket(p)
  124. r.Unlock()
  125. return
  126. case KeepRouting:
  127. receiver.InjectUDPPacket(p)
  128. default:
  129. panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
  130. }
  131. r.Unlock()
  132. }
  133. }
  134. // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
  135. // If the router doesn't have the nebula controller for that address, we panic
  136. func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) {
  137. h := &header.H{}
  138. r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
  139. if err := h.Parse(p.Data); err != nil {
  140. panic(err)
  141. }
  142. if h.Type == msgType && h.Subtype == subType {
  143. return RouteAndExit
  144. }
  145. return KeepRouting
  146. })
  147. }
  148. // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
  149. // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
  150. // If the router doesn't have the nebula controller for that address, we panic
  151. func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
  152. if finish == KeepRouting {
  153. finish = RouteAndExit
  154. }
  155. r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
  156. if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
  157. return finish
  158. }
  159. return KeepRouting
  160. })
  161. }
  162. // RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from
  163. // whatDo can return:
  164. // - exitNow: the packet will not be routed and this call will return immediately
  165. // - routeAndExit: this call will return immediately after routing the last packet from sender
  166. // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
  167. func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
  168. sc := make([]reflect.SelectCase, len(r.controls))
  169. cm := make([]*nebula.Control, len(r.controls))
  170. i := 0
  171. for _, c := range r.controls {
  172. sc[i] = reflect.SelectCase{
  173. Dir: reflect.SelectRecv,
  174. Chan: reflect.ValueOf(c.GetUDPTxChan()),
  175. Send: reflect.Value{},
  176. }
  177. cm[i] = c
  178. i++
  179. }
  180. for {
  181. x, rx, _ := reflect.Select(sc)
  182. r.Lock()
  183. p := rx.Interface().(*udp.Packet)
  184. outAddr := cm[x].GetUDPAddr()
  185. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  186. receiver := r.getControl(outAddr, inAddr, p)
  187. if receiver == nil {
  188. r.Unlock()
  189. panic("Can't route for host: " + inAddr)
  190. }
  191. e := whatDo(p, receiver)
  192. switch e {
  193. case ExitNow:
  194. r.Unlock()
  195. return
  196. case RouteAndExit:
  197. receiver.InjectUDPPacket(p)
  198. r.Unlock()
  199. return
  200. case KeepRouting:
  201. receiver.InjectUDPPacket(p)
  202. default:
  203. panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
  204. }
  205. r.Unlock()
  206. }
  207. }
  208. // FlushAll will route for every registered controller, exiting once there are no packets left to route
  209. func (r *R) FlushAll() {
  210. sc := make([]reflect.SelectCase, len(r.controls))
  211. cm := make([]*nebula.Control, len(r.controls))
  212. i := 0
  213. for _, c := range r.controls {
  214. sc[i] = reflect.SelectCase{
  215. Dir: reflect.SelectRecv,
  216. Chan: reflect.ValueOf(c.GetUDPTxChan()),
  217. Send: reflect.Value{},
  218. }
  219. cm[i] = c
  220. i++
  221. }
  222. // Add a default case to exit when nothing is left to send
  223. sc = append(sc, reflect.SelectCase{
  224. Dir: reflect.SelectDefault,
  225. Chan: reflect.Value{},
  226. Send: reflect.Value{},
  227. })
  228. for {
  229. x, rx, ok := reflect.Select(sc)
  230. if !ok {
  231. return
  232. }
  233. r.Lock()
  234. p := rx.Interface().(*udp.Packet)
  235. outAddr := cm[x].GetUDPAddr()
  236. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  237. receiver := r.getControl(outAddr, inAddr, p)
  238. if receiver == nil {
  239. r.Unlock()
  240. panic("Can't route for host: " + inAddr)
  241. }
  242. r.Unlock()
  243. }
  244. }
  245. // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
  246. // This is an internal router function, the caller must hold the lock
  247. func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
  248. if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
  249. p.FromIp = newAddr.IP
  250. p.FromPort = uint16(newAddr.Port)
  251. }
  252. c, ok := r.inNat[toAddr]
  253. if ok {
  254. sHost, sPort, err := net.SplitHostPort(toAddr)
  255. if err != nil {
  256. panic(err)
  257. }
  258. port, err := strconv.Atoi(sPort)
  259. if err != nil {
  260. panic(err)
  261. }
  262. r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
  263. IP: net.ParseIP(sHost),
  264. Port: port,
  265. }
  266. return c
  267. }
  268. return r.controls[toAddr]
  269. }