tun_disabled.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "net"
  7. "strings"
  8. "github.com/rcrowley/go-metrics"
  9. "github.com/sirupsen/logrus"
  10. )
  11. type disabledTun struct {
  12. read chan []byte
  13. cidr *net.IPNet
  14. // Track these metrics since we don't have the tun device to do it for us
  15. tx metrics.Counter
  16. rx metrics.Counter
  17. l *logrus.Logger
  18. }
  19. func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
  20. tun := &disabledTun{
  21. cidr: cidr,
  22. read: make(chan []byte, queueLen),
  23. l: l,
  24. }
  25. if metricsEnabled {
  26. tun.tx = metrics.GetOrRegisterCounter("messages.tx.message", nil)
  27. tun.rx = metrics.GetOrRegisterCounter("messages.rx.message", nil)
  28. } else {
  29. tun.tx = &metrics.NilCounter{}
  30. tun.rx = &metrics.NilCounter{}
  31. }
  32. return tun
  33. }
  34. func (*disabledTun) Activate() error {
  35. return nil
  36. }
  // CidrNet returns the network stored in t.cidr (set by newDisabledTun).
func (t *disabledTun) CidrNet() *net.IPNet {
	network := t.cidr
	return network
}
  40. func (*disabledTun) DeviceName() string {
  41. return "disabled"
  42. }
  // Read blocks until a packet is queued on t.read, copies it into b and returns
// the number of bytes copied. In this file the only producer is
// handleICMPEchoRequest, which queues ICMP Echo Replies.
//
// It returns io.EOF once the device has been closed. If the queued packet is
// larger than b, that packet is dropped and an error is returned.
func (t *disabledTun) Read(b []byte) (int, error) {
	// Close closes the channel and then sets t.read to nil. Receiving from a
	// nil channel blocks forever, so a Read issued after Close must report EOF
	// here instead of hanging. Read the field once so both checks agree.
	ch := t.read
	if ch == nil {
		return 0, io.EOF
	}

	r, ok := <-ch
	if !ok {
		return 0, io.EOF
	}

	if len(r) > len(b) {
		return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
	}

	// Count the packet ourselves; there is no tun device to do it for us.
	t.tx.Inc(1)
	if t.l.Level >= logrus.DebugLevel {
		t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
	}

	return copy(b, r), nil
}
  // handleICMPEchoRequest answers a simple IPv4 ICMP Echo Request in b by
// queueing the matching Echo Reply on t.read, from which Read returns it.
//
// Only unfragmented IPv4 packets without options (first byte 0x45: version 4,
// IHL 5) carrying protocol 1 (ICMP) with ICMP type 8 (Echo Request), and whose
// length is between 28 bytes (20-byte IP header + 8-byte ICMP header) and mtu,
// are handled. b itself is never modified; the reply is built in a copy.
//
// It returns true when b was recognised as such a request, even if the reply
// was dropped because the read queue was full, and false otherwise.
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
	// Return early if this is not a simple ICMP Echo Request:
	// b[0] == 0x45 (IPv4, 20-byte header), b[9] == 0x01 (ICMP),
	// b[20] == 0x08 (ICMP type Echo Request).
	if !(len(b) >= 28 && len(b) <= mtu && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) {
		return false
	}

	// We don't support fragmented packets. Byte 6 holds the flags (reserved
	// 0x80, DF 0x40, MF 0x20) and the top 5 bits (0x1F) of the 13-bit fragment
	// offset; byte 7 holds the low 8 bits. A packet is a fragment if MF is set
	// or the offset is non-zero, hence the 0x3F mask (MF | 0x1F). DF is allowed.
	if b[7] != 0 || b[6]&0x3F != 0 {
		return false
	}

	buf := make([]byte, len(b))
	copy(buf, b)

	// Swap dest / src IPs and recalculate checksum
	ipv4 := buf[0:20]
	copy(ipv4[12:16], b[16:20])
	copy(ipv4[16:20], b[12:16])
	ipv4[10] = 0
	ipv4[11] = 0
	binary.BigEndian.PutUint16(ipv4[10:], ipChecksum(ipv4))

	// Change type to ICMP Echo Reply (type 0) and recalculate the ICMP checksum,
	// which covers the ICMP header and the echoed payload.
	icmp := buf[20:]
	icmp[0] = 0
	icmp[2] = 0
	icmp[3] = 0
	binary.BigEndian.PutUint16(icmp[2:], ipChecksum(icmp))

	// attempt to write it, but don't block
	select {
	case t.read <- buf:
	default:
		t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
	}

	return true
}
  89. func (t *disabledTun) Write(b []byte) (int, error) {
  90. t.rx.Inc(1)
  91. // Check for ICMP Echo Request before spending time doing the full parsing
  92. if t.handleICMPEchoRequest(b) {
  93. if t.l.Level >= logrus.DebugLevel {
  94. t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
  95. }
  96. } else if t.l.Level >= logrus.DebugLevel {
  97. t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
  98. }
  99. return len(b), nil
  100. }
  101. func (t *disabledTun) WriteRaw(b []byte) error {
  102. _, err := t.Write(b)
  103. return err
  104. }
  105. func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  106. return t, nil
  107. }
  108. func (t *disabledTun) Close() error {
  109. if t.read != nil {
  110. close(t.read)
  111. t.read = nil
  112. }
  113. return nil
  114. }
  115. type prettyPacket []byte
  116. func (p prettyPacket) String() string {
  117. var s strings.Builder
  118. for i, b := range p {
  119. if i > 0 && i%8 == 0 {
  120. s.WriteString(" ")
  121. }
  122. s.WriteString(fmt.Sprintf("%02x ", b))
  123. }
  124. return s.String()
  125. }
  // ipChecksum computes the Internet checksum of b: the one's complement of the
// one's-complement sum of its big-endian 16-bit words. An odd trailing byte is
// treated as the high byte of a final word padded with zero. Callers in this
// file zero the checksum field inside b before calling.
func ipChecksum(b []byte) uint16 {
	var sum uint32
	rest := b
	for len(rest) >= 2 {
		sum += uint32(rest[0])<<8 | uint32(rest[1])
		rest = rest[2:]
	}
	if len(rest) == 1 {
		sum += uint32(rest[0]) << 8
	}

	// Fold carries back into the low 16 bits until none remain.
	for sum>>16 != 0 {
		sum = sum&0xffff + sum>>16
	}
	return ^uint16(sum)
}