tun_disabled.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package overlay
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "net"
  7. "strings"
  8. "github.com/rcrowley/go-metrics"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/iputil"
  11. )
  12. type disabledTun struct {
  13. read chan []byte
  14. cidr *net.IPNet
  15. // Track these metrics since we don't have the tun device to do it for us
  16. tx metrics.Counter
  17. rx metrics.Counter
  18. l *logrus.Logger
  19. }
  20. func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
  21. tun := &disabledTun{
  22. cidr: cidr,
  23. read: make(chan []byte, queueLen),
  24. l: l,
  25. }
  26. if metricsEnabled {
  27. tun.tx = metrics.GetOrRegisterCounter("messages.tx.message", nil)
  28. tun.rx = metrics.GetOrRegisterCounter("messages.rx.message", nil)
  29. } else {
  30. tun.tx = &metrics.NilCounter{}
  31. tun.rx = &metrics.NilCounter{}
  32. }
  33. return tun
  34. }
  35. func (*disabledTun) Activate() error {
  36. return nil
  37. }
  38. func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
  39. return 0
  40. }
  41. func (t *disabledTun) CidrNet() *net.IPNet {
  42. return t.cidr
  43. }
  44. func (*disabledTun) DeviceName() string {
  45. return "disabled"
  46. }
  47. func (t *disabledTun) Read(b []byte) (int, error) {
  48. r, ok := <-t.read
  49. if !ok {
  50. return 0, io.EOF
  51. }
  52. if len(r) > len(b) {
  53. return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
  54. }
  55. t.tx.Inc(1)
  56. if t.l.Level >= logrus.DebugLevel {
  57. t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
  58. }
  59. return copy(b, r), nil
  60. }
  61. func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
  62. // Return early if this is not a simple ICMP Echo Request
  63. //TODO: make constants out of these
  64. if !(len(b) >= 28 && len(b) <= 9001 && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) {
  65. return false
  66. }
  67. // We don't support fragmented packets
  68. if b[7] != 0 || (b[6]&0x2F != 0) {
  69. return false
  70. }
  71. buf := make([]byte, len(b))
  72. copy(buf, b)
  73. // Swap dest / src IPs and recalculate checksum
  74. ipv4 := buf[0:20]
  75. copy(ipv4[12:16], b[16:20])
  76. copy(ipv4[16:20], b[12:16])
  77. ipv4[10] = 0
  78. ipv4[11] = 0
  79. binary.BigEndian.PutUint16(ipv4[10:], ipChecksum(ipv4))
  80. // Change type to ICMP Echo Reply and recalculate checksum
  81. icmp := buf[20:]
  82. icmp[0] = 0
  83. icmp[2] = 0
  84. icmp[3] = 0
  85. binary.BigEndian.PutUint16(icmp[2:], ipChecksum(icmp))
  86. // attempt to write it, but don't block
  87. select {
  88. case t.read <- buf:
  89. default:
  90. t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
  91. }
  92. return true
  93. }
  94. func (t *disabledTun) Write(b []byte) (int, error) {
  95. t.rx.Inc(1)
  96. // Check for ICMP Echo Request before spending time doing the full parsing
  97. if t.handleICMPEchoRequest(b) {
  98. if t.l.Level >= logrus.DebugLevel {
  99. t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
  100. }
  101. } else if t.l.Level >= logrus.DebugLevel {
  102. t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
  103. }
  104. return len(b), nil
  105. }
  106. func (t *disabledTun) WriteRaw(b []byte) error {
  107. _, err := t.Write(b)
  108. return err
  109. }
  110. func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  111. return t, nil
  112. }
  113. func (t *disabledTun) Close() error {
  114. if t.read != nil {
  115. close(t.read)
  116. t.read = nil
  117. }
  118. return nil
  119. }
  // prettyPacket wraps raw packet bytes so they render as hex in debug logs.
  type prettyPacket []byte

  // String renders p as lowercase two-digit hex, each byte followed by a
  // space. An extra space is inserted before every group of 8 bytes after the
  // first, e.g. "00 01 02 03 04 05 06 07  08 ". An empty packet renders as "".
  func (p prettyPacket) String() string {
  	const hexDigits = "0123456789abcdef"

  	var s strings.Builder
  	// Three characters per byte plus one separator per 8-byte boundary.
  	if len(p) > 0 {
  		s.Grow(3*len(p) + (len(p)-1)/8)
  	}
  	for i, b := range p {
  		if i > 0 && i%8 == 0 {
  			s.WriteByte(' ')
  		}
  		// Encode directly instead of fmt.Sprintf, which allocated per byte.
  		s.WriteByte(hexDigits[b>>4])
  		s.WriteByte(hexDigits[b&0x0f])
  		s.WriteByte(' ')
  	}
  	return s.String()
  }
  // ipChecksum returns the Internet checksum of b: the ones' complement of
  // the ones' complement sum of its big-endian 16-bit words. A trailing odd
  // byte is treated as the high half of a zero-padded word. The checksum of
  // an empty slice is 0xffff.
  func ipChecksum(b []byte) uint16 {
  	var sum uint32
  	rest := b
  	for len(rest) >= 2 {
  		sum += uint32(rest[0])<<8 | uint32(rest[1])
  		rest = rest[2:]
  	}
  	if len(rest) == 1 {
  		sum += uint32(rest[0]) << 8
  	}
  	// Fold carries back into the low 16 bits until none remain.
  	for sum > 0xffff {
  		sum = sum&0xffff + sum>>16
  	}
  	return ^uint16(sum)
  }