connection_state.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package nebula
  2. import (
  3. "crypto/rand"
  4. "encoding/json"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "github.com/flynn/noise"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/noiseutil"
  12. )
  13. // ReplayWindow controls the size of the sliding window used to detect replays.
  14. // High-bandwidth links with GRO/GSO can reorder more than a thousand packets in
  15. // flight, so keep this comfortably above the largest expected burst.
  16. const ReplayWindow = 32768
  17. type ConnectionState struct {
  18. eKey *NebulaCipherState
  19. dKey *NebulaCipherState
  20. H *noise.HandshakeState
  21. myCert cert.Certificate
  22. peerCert *cert.CachedCertificate
  23. initiator bool
  24. messageCounter atomic.Uint64
  25. window *Bits
  26. writeLock sync.Mutex
  27. }
  28. func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
  29. var dhFunc noise.DHFunc
  30. switch crt.Curve() {
  31. case cert.Curve_CURVE25519:
  32. dhFunc = noise.DH25519
  33. case cert.Curve_P256:
  34. if cs.pkcs11Backed {
  35. dhFunc = noiseutil.DHP256PKCS11
  36. } else {
  37. dhFunc = noiseutil.DHP256
  38. }
  39. default:
  40. return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
  41. }
  42. var ncs noise.CipherSuite
  43. if cs.cipher == "chachapoly" {
  44. ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
  45. } else {
  46. ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
  47. }
  48. static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
  49. b := NewBits(ReplayWindow)
  50. // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
  51. b.Update(l, 0)
  52. hs, err := noise.NewHandshakeState(noise.Config{
  53. CipherSuite: ncs,
  54. Random: rand.Reader,
  55. Pattern: pattern,
  56. Initiator: initiator,
  57. StaticKeypair: static,
  58. //NOTE: These should come from CertState (pki.go) when we finally implement it
  59. PresharedKey: []byte{},
  60. PresharedKeyPlacement: 0,
  61. })
  62. if err != nil {
  63. return nil, fmt.Errorf("NewConnectionState: %s", err)
  64. }
  65. // The queue and ready params prevent a counter race that would happen when
  66. // sending stored packets and simultaneously accepting new traffic.
  67. ci := &ConnectionState{
  68. H: hs,
  69. initiator: initiator,
  70. window: b,
  71. myCert: crt,
  72. }
  73. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
  74. ci.messageCounter.Add(2)
  75. return ci, nil
  76. }
  77. func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
  78. return json.Marshal(m{
  79. "certificate": cs.peerCert,
  80. "initiator": cs.initiator,
  81. "message_counter": cs.messageCounter.Load(),
  82. })
  83. }
  84. func (cs *ConnectionState) Curve() cert.Curve {
  85. return cs.myCert.Curve()
  86. }