connection_state.go 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package nebula
  2. import (
  3. "crypto/rand"
  4. "encoding/json"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "github.com/flynn/noise"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/noiseutil"
  12. )
  // ReplayWindow is the size handed to NewBits when NewConnectionState builds
  // the per-connection replay window.
//
// TODO: 1024 proved insufficient in a 5Gbps test. With a 1400 MTU and full
// packets this amounts to roughly 1.4Gbps worth of window; about 4092
// (NOTE(review): possibly intended as 4096 — confirm) should suffice for 5Gbps.
const ReplayWindow = 1 << 10
  16. type ConnectionState struct {
  17. eKey *NebulaCipherState
  18. dKey *NebulaCipherState
  19. H *noise.HandshakeState
  20. myCert cert.Certificate
  21. peerCert *cert.CachedCertificate
  22. initiator bool
  23. messageCounter atomic.Uint64
  24. window *Bits
  25. writeLock sync.Mutex
  26. }
  27. func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
  28. var dhFunc noise.DHFunc
  29. switch crt.Curve() {
  30. case cert.Curve_CURVE25519:
  31. dhFunc = noise.DH25519
  32. case cert.Curve_P256:
  33. if cs.pkcs11Backed {
  34. dhFunc = noiseutil.DHP256PKCS11
  35. } else {
  36. dhFunc = noiseutil.DHP256
  37. }
  38. default:
  39. return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
  40. }
  41. var ncs noise.CipherSuite
  42. if cs.cipher == "chachapoly" {
  43. ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
  44. } else {
  45. ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
  46. }
  47. static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
  48. b := NewBits(ReplayWindow)
  49. // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
  50. b.Update(l, 0)
  51. hs, err := noise.NewHandshakeState(noise.Config{
  52. CipherSuite: ncs,
  53. Random: rand.Reader,
  54. Pattern: pattern,
  55. Initiator: initiator,
  56. StaticKeypair: static,
  57. //NOTE: These should come from CertState (pki.go) when we finally implement it
  58. PresharedKey: []byte{},
  59. PresharedKeyPlacement: 0,
  60. })
  61. if err != nil {
  62. return nil, fmt.Errorf("NewConnectionState: %s", err)
  63. }
  64. // The queue and ready params prevent a counter race that would happen when
  65. // sending stored packets and simultaneously accepting new traffic.
  66. ci := &ConnectionState{
  67. H: hs,
  68. initiator: initiator,
  69. window: b,
  70. myCert: crt,
  71. }
  72. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
  73. ci.messageCounter.Add(2)
  74. return ci, nil
  75. }
  76. func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
  77. return json.Marshal(m{
  78. "certificate": cs.peerCert,
  79. "initiator": cs.initiator,
  80. "message_counter": cs.messageCounter.Load(),
  81. })
  82. }
  83. func (cs *ConnectionState) Curve() cert.Curve {
  84. return cs.myCert.Curve()
  85. }