router.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. // +build e2e_testing
  2. package router
  3. import (
  4. "fmt"
  5. "net"
  6. "reflect"
  7. "strconv"
  8. "sync"
  9. "github.com/slackhq/nebula"
  10. )
  11. type R struct {
  12. // Simple map of the ip:port registered on a control to the control
  13. // Basically a router, right?
  14. controls map[string]*nebula.Control
  15. // A map for inbound packets for a control that doesn't know about this address
  16. inNat map[string]*nebula.Control
  17. // A last used map, if an inbound packet hit the inNat map then
  18. // all return packets should use the same last used inbound address for the outbound sender
  19. // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
  20. outNat map[string]net.UDPAddr
  21. // All interactions are locked to help serialize behavior
  22. sync.Mutex
  23. }
  24. type ExitType int
  25. const (
  26. // Keeps routing, the function will get called again on the next packet
  27. KeepRouting ExitType = 0
  28. // Does not route this packet and exits immediately
  29. ExitNow ExitType = 1
  30. // Routes this packet and exits immediately afterwards
  31. RouteAndExit ExitType = 2
  32. )
  33. type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
  34. func NewR(controls ...*nebula.Control) *R {
  35. r := &R{
  36. controls: make(map[string]*nebula.Control),
  37. inNat: make(map[string]*nebula.Control),
  38. outNat: make(map[string]net.UDPAddr),
  39. }
  40. for _, c := range controls {
  41. addr := c.GetUDPAddr()
  42. if _, ok := r.controls[addr]; ok {
  43. panic("Duplicate listen address: " + addr)
  44. }
  45. r.controls[addr] = c
  46. }
  47. return r
  48. }
  49. // AddRoute will place the nebula controller at the ip and port specified.
  50. // It does not look at the addr attached to the instance.
  51. // If a route is used, this will behave like a NAT for the return path.
  52. // Rewriting the source ip:port to what was last sent to from the origin
  53. func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
  54. r.Lock()
  55. defer r.Unlock()
  56. inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
  57. if _, ok := r.inNat[inAddr]; ok {
  58. panic("Duplicate listen address inNat: " + inAddr)
  59. }
  60. r.inNat[inAddr] = c
  61. }
  62. // OnceFrom will route a single packet from sender then return
  63. // If the router doesn't have the nebula controller for that address, we panic
  64. func (r *R) OnceFrom(sender *nebula.Control) {
  65. r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
  66. return RouteAndExit
  67. })
  68. }
  69. // RouteUntilTxTun will route for sender and return when a packet is seen on receivers tun
  70. // If the router doesn't have the nebula controller for that address, we panic
  71. func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []byte {
  72. tunTx := receiver.GetTunTxChan()
  73. udpTx := sender.GetUDPTxChan()
  74. for {
  75. select {
  76. // Maybe we already have something on the tun for us
  77. case b := <-tunTx:
  78. return b
  79. // Nope, lets push the sender along
  80. case p := <-udpTx:
  81. outAddr := sender.GetUDPAddr()
  82. r.Lock()
  83. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  84. c := r.getControl(outAddr, inAddr, p)
  85. if c == nil {
  86. r.Unlock()
  87. panic("No control for udp tx")
  88. }
  89. c.InjectUDPPacket(p)
  90. r.Unlock()
  91. }
  92. }
  93. }
  94. // RouteExitFunc will call the whatDo func with each udp packet from sender.
  95. // whatDo can return:
  96. // - exitNow: the packet will not be routed and this call will return immediately
  97. // - routeAndExit: this call will return immediately after routing the last packet from sender
  98. // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
  99. func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
  100. h := &nebula.Header{}
  101. for {
  102. p := sender.GetFromUDP(true)
  103. r.Lock()
  104. if err := h.Parse(p.Data); err != nil {
  105. panic(err)
  106. }
  107. outAddr := sender.GetUDPAddr()
  108. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  109. receiver := r.getControl(outAddr, inAddr, p)
  110. if receiver == nil {
  111. r.Unlock()
  112. panic("Can't route for host: " + inAddr)
  113. }
  114. e := whatDo(p, receiver)
  115. switch e {
  116. case ExitNow:
  117. r.Unlock()
  118. return
  119. case RouteAndExit:
  120. receiver.InjectUDPPacket(p)
  121. r.Unlock()
  122. return
  123. case KeepRouting:
  124. receiver.InjectUDPPacket(p)
  125. default:
  126. panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
  127. }
  128. r.Unlock()
  129. }
  130. }
  131. // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
  132. // If the router doesn't have the nebula controller for that address, we panic
  133. func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
  134. h := &nebula.Header{}
  135. r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
  136. if err := h.Parse(p.Data); err != nil {
  137. panic(err)
  138. }
  139. if h.Type == msgType && h.Subtype == subType {
  140. return RouteAndExit
  141. }
  142. return KeepRouting
  143. })
  144. }
  145. // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
  146. // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
  147. // If the router doesn't have the nebula controller for that address, we panic
  148. func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
  149. if finish == KeepRouting {
  150. finish = RouteAndExit
  151. }
  152. r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
  153. if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
  154. return finish
  155. }
  156. return KeepRouting
  157. })
  158. }
  159. // RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from
  160. // whatDo can return:
  161. // - exitNow: the packet will not be routed and this call will return immediately
  162. // - routeAndExit: this call will return immediately after routing the last packet from sender
  163. // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
  164. func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
  165. sc := make([]reflect.SelectCase, len(r.controls))
  166. cm := make([]*nebula.Control, len(r.controls))
  167. i := 0
  168. for _, c := range r.controls {
  169. sc[i] = reflect.SelectCase{
  170. Dir: reflect.SelectRecv,
  171. Chan: reflect.ValueOf(c.GetUDPTxChan()),
  172. Send: reflect.Value{},
  173. }
  174. cm[i] = c
  175. i++
  176. }
  177. for {
  178. x, rx, _ := reflect.Select(sc)
  179. r.Lock()
  180. p := rx.Interface().(*nebula.UdpPacket)
  181. outAddr := cm[x].GetUDPAddr()
  182. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  183. receiver := r.getControl(outAddr, inAddr, p)
  184. if receiver == nil {
  185. r.Unlock()
  186. panic("Can't route for host: " + inAddr)
  187. }
  188. e := whatDo(p, receiver)
  189. switch e {
  190. case ExitNow:
  191. r.Unlock()
  192. return
  193. case RouteAndExit:
  194. receiver.InjectUDPPacket(p)
  195. r.Unlock()
  196. return
  197. case KeepRouting:
  198. receiver.InjectUDPPacket(p)
  199. default:
  200. panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
  201. }
  202. r.Unlock()
  203. }
  204. }
  205. // FlushAll will route for every registered controller, exiting once there are no packets left to route
  206. func (r *R) FlushAll() {
  207. sc := make([]reflect.SelectCase, len(r.controls))
  208. cm := make([]*nebula.Control, len(r.controls))
  209. i := 0
  210. for _, c := range r.controls {
  211. sc[i] = reflect.SelectCase{
  212. Dir: reflect.SelectRecv,
  213. Chan: reflect.ValueOf(c.GetUDPTxChan()),
  214. Send: reflect.Value{},
  215. }
  216. cm[i] = c
  217. i++
  218. }
  219. // Add a default case to exit when nothing is left to send
  220. sc = append(sc, reflect.SelectCase{
  221. Dir: reflect.SelectDefault,
  222. Chan: reflect.Value{},
  223. Send: reflect.Value{},
  224. })
  225. for {
  226. x, rx, ok := reflect.Select(sc)
  227. if !ok {
  228. return
  229. }
  230. r.Lock()
  231. p := rx.Interface().(*nebula.UdpPacket)
  232. outAddr := cm[x].GetUDPAddr()
  233. inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
  234. receiver := r.getControl(outAddr, inAddr, p)
  235. if receiver == nil {
  236. r.Unlock()
  237. panic("Can't route for host: " + inAddr)
  238. }
  239. r.Unlock()
  240. }
  241. }
  242. // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
  243. // This is an internal router function, the caller must hold the lock
  244. func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
  245. if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
  246. p.FromIp = newAddr.IP
  247. p.FromPort = uint16(newAddr.Port)
  248. }
  249. c, ok := r.inNat[toAddr]
  250. if ok {
  251. sHost, sPort, err := net.SplitHostPort(toAddr)
  252. if err != nil {
  253. panic(err)
  254. }
  255. port, err := strconv.Atoi(sPort)
  256. if err != nil {
  257. panic(err)
  258. }
  259. r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
  260. IP: net.ParseIP(sHost),
  261. Port: port,
  262. }
  263. return c
  264. }
  265. return r.controls[toAddr]
  266. }