tun_disabled.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "net"
  7. "strings"
  8. "github.com/rcrowley/go-metrics"
  9. "github.com/sirupsen/logrus"
  10. log "github.com/sirupsen/logrus"
  11. )
  12. type disabledTun struct {
  13. read chan []byte
  14. cidr *net.IPNet
  15. logger *log.Logger
  16. // Track these metrics since we don't have the tun device to do it for us
  17. tx metrics.Counter
  18. rx metrics.Counter
  19. }
  20. func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun {
  21. tun := &disabledTun{
  22. cidr: cidr,
  23. read: make(chan []byte, queueLen),
  24. logger: l,
  25. }
  26. if metricsEnabled {
  27. tun.tx = metrics.GetOrRegisterCounter("messages.tx.message", nil)
  28. tun.rx = metrics.GetOrRegisterCounter("messages.rx.message", nil)
  29. } else {
  30. tun.tx = &metrics.NilCounter{}
  31. tun.rx = &metrics.NilCounter{}
  32. }
  33. return tun
  34. }
  35. func (*disabledTun) Activate() error {
  36. return nil
  37. }
  38. func (t *disabledTun) CidrNet() *net.IPNet {
  39. return t.cidr
  40. }
  41. func (*disabledTun) DeviceName() string {
  42. return "disabled"
  43. }
  44. func (t *disabledTun) Read(b []byte) (int, error) {
  45. r, ok := <-t.read
  46. if !ok {
  47. return 0, io.EOF
  48. }
  49. if len(r) > len(b) {
  50. return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
  51. }
  52. t.tx.Inc(1)
  53. if l.Level >= logrus.DebugLevel {
  54. t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload")
  55. }
  56. return copy(b, r), nil
  57. }
  58. func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
  59. // Return early if this is not a simple ICMP Echo Request
  60. if !(len(b) >= 28 && len(b) <= mtu && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) {
  61. return false
  62. }
  63. // We don't support fragmented packets
  64. if b[7] != 0 || (b[6]&0x2F != 0) {
  65. return false
  66. }
  67. buf := make([]byte, len(b))
  68. copy(buf, b)
  69. // Swap dest / src IPs and recalculate checksum
  70. ipv4 := buf[0:20]
  71. copy(ipv4[12:16], b[16:20])
  72. copy(ipv4[16:20], b[12:16])
  73. ipv4[10] = 0
  74. ipv4[11] = 0
  75. binary.BigEndian.PutUint16(ipv4[10:], ipChecksum(ipv4))
  76. // Change type to ICMP Echo Reply and recalculate checksum
  77. icmp := buf[20:]
  78. icmp[0] = 0
  79. icmp[2] = 0
  80. icmp[3] = 0
  81. binary.BigEndian.PutUint16(icmp[2:], ipChecksum(icmp))
  82. // attempt to write it, but don't block
  83. select {
  84. case t.read <- buf:
  85. default:
  86. t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response")
  87. }
  88. return true
  89. }
  90. func (t *disabledTun) Write(b []byte) (int, error) {
  91. t.rx.Inc(1)
  92. // Check for ICMP Echo Request before spending time doing the full parsing
  93. if t.handleICMPEchoRequest(b) {
  94. if l.Level >= logrus.DebugLevel {
  95. t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
  96. }
  97. } else if l.Level >= logrus.DebugLevel {
  98. t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
  99. }
  100. return len(b), nil
  101. }
  102. func (t *disabledTun) WriteRaw(b []byte) error {
  103. _, err := t.Write(b)
  104. return err
  105. }
  106. func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  107. return t, nil
  108. }
  109. func (t *disabledTun) Close() error {
  110. if t.read != nil {
  111. close(t.read)
  112. t.read = nil
  113. }
  114. return nil
  115. }
  116. type prettyPacket []byte
  117. func (p prettyPacket) String() string {
  118. var s strings.Builder
  119. for i, b := range p {
  120. if i > 0 && i%8 == 0 {
  121. s.WriteString(" ")
  122. }
  123. s.WriteString(fmt.Sprintf("%02x ", b))
  124. }
  125. return s.String()
  126. }
  127. func ipChecksum(b []byte) uint16 {
  128. var c uint32
  129. sz := len(b) - 1
  130. for i := 0; i < sz; i += 2 {
  131. c += uint32(b[i]) << 8
  132. c += uint32(b[i+1])
  133. }
  134. if sz%2 == 0 {
  135. c += uint32(b[sz]) << 8
  136. }
  137. for (c >> 16) > 0 {
  138. c = (c & 0xffff) + (c >> 16)
  139. }
  140. return ^uint16(c)
  141. }