connection_state.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package nebula
  2. import (
  3. "crypto/rand"
  4. "encoding/json"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "github.com/flynn/noise"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/noiseutil"
  12. )
  const (
  	// ReplayWindow is the size passed to NewBits when a ConnectionState's
  	// replay window is allocated.
  	ReplayWindow = 1024
  )
  14. type ConnectionState struct {
  15. eKey *NebulaCipherState
  16. dKey *NebulaCipherState
  17. H *noise.HandshakeState
  18. myCert cert.Certificate
  19. peerCert *cert.CachedCertificate
  20. initiator bool
  21. messageCounter atomic.Uint64
  22. window *Bits
  23. writeLock sync.Mutex
  24. }
  25. func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
  26. var dhFunc noise.DHFunc
  27. switch crt.Curve() {
  28. case cert.Curve_CURVE25519:
  29. dhFunc = noise.DH25519
  30. case cert.Curve_P256:
  31. if cs.pkcs11Backed {
  32. dhFunc = noiseutil.DHP256PKCS11
  33. } else {
  34. dhFunc = noiseutil.DHP256
  35. }
  36. default:
  37. return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
  38. }
  39. var ncs noise.CipherSuite
  40. if cs.cipher == "chachapoly" {
  41. ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
  42. } else {
  43. ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
  44. }
  45. static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
  46. hs, err := noise.NewHandshakeState(noise.Config{
  47. CipherSuite: ncs,
  48. Random: rand.Reader,
  49. Pattern: pattern,
  50. Initiator: initiator,
  51. StaticKeypair: static,
  52. //NOTE: These should come from CertState (pki.go) when we finally implement it
  53. PresharedKey: []byte{},
  54. PresharedKeyPlacement: 0,
  55. })
  56. if err != nil {
  57. return nil, fmt.Errorf("NewConnectionState: %s", err)
  58. }
  59. // The queue and ready params prevent a counter race that would happen when
  60. // sending stored packets and simultaneously accepting new traffic.
  61. ci := &ConnectionState{
  62. H: hs,
  63. initiator: initiator,
  64. window: NewBits(ReplayWindow),
  65. myCert: crt,
  66. }
  67. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
  68. ci.messageCounter.Add(2)
  69. return ci, nil
  70. }
  71. func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
  72. return json.Marshal(m{
  73. "certificate": cs.peerCert,
  74. "initiator": cs.initiator,
  75. "message_counter": cs.messageCounter.Load(),
  76. })
  77. }
  78. func (cs *ConnectionState) Curve() cert.Curve {
  79. return cs.myCert.Curve()
  80. }