connection_state.go 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package nebula
  2. import (
  3. "crypto/rand"
  4. "encoding/json"
  5. "fmt"
  6. "sync/atomic"
  7. "github.com/flynn/noise"
  8. "github.com/sirupsen/logrus"
  9. "github.com/slackhq/nebula/cert"
  10. "github.com/slackhq/nebula/noiseutil"
  11. )
  // ReplayWindow is the size handed to NewBits when NewConnectionState builds
  // the window for a new connection (1 << 10 == 1024).
  const ReplayWindow = 1 << 10
  13. type ConnectionState struct {
  14. eKey *NebulaCipherState
  15. dKey *NebulaCipherState
  16. H *noise.HandshakeState
  17. myCert cert.Certificate
  18. peerCert *cert.CachedCertificate
  19. initiator bool
  20. messageCounter atomic.Uint64
  21. window *Bits
  22. writeLock syncMutex
  23. }
  24. func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
  25. var dhFunc noise.DHFunc
  26. switch crt.Curve() {
  27. case cert.Curve_CURVE25519:
  28. dhFunc = noise.DH25519
  29. case cert.Curve_P256:
  30. if cs.pkcs11Backed {
  31. dhFunc = noiseutil.DHP256PKCS11
  32. } else {
  33. dhFunc = noiseutil.DHP256
  34. }
  35. default:
  36. return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
  37. }
  38. var ncs noise.CipherSuite
  39. if cs.cipher == "chachapoly" {
  40. ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
  41. } else {
  42. ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
  43. }
  44. static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
  45. b := NewBits(ReplayWindow)
  46. // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
  47. b.Update(l, 0)
  48. hs, err := noise.NewHandshakeState(noise.Config{
  49. CipherSuite: ncs,
  50. Random: rand.Reader,
  51. Pattern: pattern,
  52. Initiator: initiator,
  53. StaticKeypair: static,
  54. //NOTE: These should come from CertState (pki.go) when we finally implement it
  55. PresharedKey: []byte{},
  56. PresharedKeyPlacement: 0,
  57. })
  58. if err != nil {
  59. return nil, fmt.Errorf("NewConnectionState: %s", err)
  60. }
  61. // The queue and ready params prevent a counter race that would happen when
  62. // sending stored packets and simultaneously accepting new traffic.
  63. ci := &ConnectionState{
  64. H: hs,
  65. initiator: initiator,
  66. window: b,
  67. myCert: crt,
  68. writeLock: newSyncMutex("connection-state-write"),
  69. }
  70. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
  71. ci.messageCounter.Add(2)
  72. return ci, nil
  73. }
  74. func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
  75. return json.Marshal(m{
  76. "certificate": cs.peerCert,
  77. "initiator": cs.initiator,
  78. "message_counter": cs.messageCounter.Load(),
  79. })
  80. }
  81. func (cs *ConnectionState) Curve() cert.Curve {
  82. return cs.myCert.Curve()
  83. }