pki.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "os"
  10. "slices"
  11. "strings"
  12. "sync/atomic"
  13. "time"
  14. "github.com/gaissmai/bart"
  15. "github.com/sirupsen/logrus"
  16. "github.com/slackhq/nebula/cert"
  17. "github.com/slackhq/nebula/config"
  18. "github.com/slackhq/nebula/util"
  19. )
// PKI bundles the host's certificate state and its CA pool together with the
// logger used while (re)loading them. Both values sit behind atomic pointers:
// reloadCerts and reloadCAPool publish a new value with Store, and the
// accessors read it with Load.
type PKI struct {
	// cs is the current certificate state; replaced by reloadCerts.
	cs atomic.Pointer[CertState]
	// caPool is the current CA pool; replaced by reloadCAPool.
	caPool atomic.Pointer[cert.CAPool]
	// l receives reload results and PKI diagnostics.
	l *logrus.Logger
}
// CertState is the host identity material produced by newCertState, plus the
// cipher and PSK settings that reloadCerts attaches before publishing it.
type CertState struct {
	// v1Cert is the version 1 host certificate; nil when none was supplied.
	v1Cert cert.Certificate
	// v1HandshakeBytes is the result of v1Cert.MarshalForHandshakes().
	v1HandshakeBytes []byte
	// v2Cert is the version 2 host certificate; nil when none was supplied.
	v2Cert cert.Certificate
	// v2HandshakeBytes is the result of v2Cert.MarshalForHandshakes().
	v2HandshakeBytes []byte
	// defaultVersion selects the certificate returned by GetDefaultCertificate.
	defaultVersion cert.Version
	// privateKey is the key material returned by loadPrivateKey. For a
	// "pkcs11:" key this holds the URI bytes rather than a raw key.
	privateKey []byte
	// pkcs11Backed is true when the key is a "pkcs11:" URI; newCertState then
	// skips the private/public key pair check.
	pkcs11Backed bool
	// cipher is the "cipher" config value read on the initial load and carried
	// over unchanged on every reload.
	cipher string
	// psk is the value returned by NewPskFromConfig.
	psk *Psk
	// myVpnNetworks lists Networks() of the v2 certificate when present,
	// otherwise of the v1 certificate.
	myVpnNetworks []netip.Prefix
	// myVpnNetworksTable holds the same prefixes as myVpnNetworks.
	myVpnNetworksTable *bart.Table[struct{}]
	// myVpnAddrs lists the Addr() of every entry in myVpnNetworks.
	myVpnAddrs []netip.Addr
	// myVpnAddrsTable holds each entry of myVpnAddrs as a full-length prefix.
	myVpnAddrsTable *bart.Table[struct{}]
	// myVpnBroadcastAddrsTable holds, for every IPv4 network, the address with
	// all host bits set, as a full-length prefix.
	myVpnBroadcastAddrsTable *bart.Table[struct{}]
}
  41. func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
  42. pki := &PKI{l: l}
  43. err := pki.reload(c, true)
  44. if err != nil {
  45. return nil, err
  46. }
  47. c.RegisterReloadCallback(func(c *config.C) {
  48. rErr := pki.reload(c, false)
  49. if rErr != nil {
  50. util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
  51. }
  52. })
  53. return pki, nil
  54. }
  55. func (p *PKI) GetCAPool() *cert.CAPool {
  56. return p.caPool.Load()
  57. }
  58. func (p *PKI) getCertState() *CertState {
  59. return p.cs.Load()
  60. }
  61. func (p *PKI) reload(c *config.C, initial bool) error {
  62. err := p.reloadCerts(c, initial)
  63. if err != nil {
  64. if initial {
  65. return err
  66. }
  67. err.Log(p.l)
  68. }
  69. err = p.reloadCAPool(c)
  70. if err != nil {
  71. if initial {
  72. return err
  73. }
  74. err.Log(p.l)
  75. }
  76. return nil
  77. }
  78. func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
  79. newState, err := newCertStateFromConfig(c)
  80. if err != nil {
  81. return util.NewContextualError("Could not load client cert", nil, err)
  82. }
  83. if !initial {
  84. currentState := p.cs.Load()
  85. if newState.v1Cert != nil {
  86. if currentState.v1Cert == nil {
  87. return util.NewContextualError("v1 certificate was added, restart required", nil, err)
  88. }
  89. // did IP in cert change? if so, don't set
  90. if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
  91. return util.NewContextualError(
  92. "Networks in new cert was different from old",
  93. m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
  94. nil,
  95. )
  96. }
  97. if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
  98. return util.NewContextualError(
  99. "Curve in new cert was different from old",
  100. m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
  101. nil,
  102. )
  103. }
  104. } else if currentState.v1Cert != nil {
  105. //TODO: CERT-V2 we should be able to tear this down
  106. return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
  107. }
  108. if newState.v2Cert != nil {
  109. if currentState.v2Cert == nil {
  110. return util.NewContextualError("v2 certificate was added, restart required", nil, err)
  111. }
  112. // did IP in cert change? if so, don't set
  113. if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
  114. return util.NewContextualError(
  115. "Networks in new cert was different from old",
  116. m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
  117. nil,
  118. )
  119. }
  120. if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
  121. return util.NewContextualError(
  122. "Curve in new cert was different from old",
  123. m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
  124. nil,
  125. )
  126. }
  127. } else if currentState.v2Cert != nil {
  128. return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
  129. }
  130. // Cipher cant be hot swapped so just leave it at what it was before
  131. newState.cipher = currentState.cipher
  132. } else {
  133. newState.cipher = c.GetString("cipher", "aes")
  134. //TODO: this sucks and we should make it not a global
  135. switch newState.cipher {
  136. case "aes":
  137. noiseEndianness = binary.BigEndian
  138. case "chachapoly":
  139. noiseEndianness = binary.LittleEndian
  140. default:
  141. return util.NewContextualError(
  142. "unknown cipher",
  143. m{"cipher": newState.cipher},
  144. nil,
  145. )
  146. }
  147. }
  148. psk, err := NewPskFromConfig(c)
  149. if err != nil {
  150. return util.NewContextualError("Failed to load psk from config", nil, err)
  151. }
  152. if len(psk.keys) > 0 {
  153. p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)).
  154. Info("pre shared keys are in use")
  155. }
  156. newState.psk = psk
  157. p.cs.Store(newState)
  158. //TODO: CERT-V2 newState needs a stringer that does json
  159. if initial {
  160. p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
  161. } else {
  162. p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
  163. }
  164. return nil
  165. }
  166. func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
  167. caPool, err := loadCAPoolFromConfig(p.l, c)
  168. if err != nil {
  169. return util.NewContextualError("Failed to load ca from config", nil, err)
  170. }
  171. p.caPool.Store(caPool)
  172. p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
  173. return nil
  174. }
  175. func (cs *CertState) GetDefaultCertificate() cert.Certificate {
  176. c := cs.getCertificate(cs.defaultVersion)
  177. if c == nil {
  178. panic("No default certificate found")
  179. }
  180. return c
  181. }
  182. func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
  183. switch v {
  184. case cert.Version1:
  185. return cs.v1Cert
  186. case cert.Version2:
  187. return cs.v2Cert
  188. }
  189. return nil
  190. }
  191. // getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
  192. // Callers must check if the return []byte is nil.
  193. func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
  194. switch v {
  195. case cert.Version1:
  196. return cs.v1HandshakeBytes
  197. case cert.Version2:
  198. return cs.v2HandshakeBytes
  199. default:
  200. return nil
  201. }
  202. }
  203. func (cs *CertState) String() string {
  204. b, err := cs.MarshalJSON()
  205. if err != nil {
  206. return fmt.Sprintf("error marshaling certificate state: %v", err)
  207. }
  208. return string(b)
  209. }
  210. func (cs *CertState) MarshalJSON() ([]byte, error) {
  211. msg := []json.RawMessage{}
  212. if cs.v1Cert != nil {
  213. b, err := cs.v1Cert.MarshalJSON()
  214. if err != nil {
  215. return nil, err
  216. }
  217. msg = append(msg, b)
  218. }
  219. if cs.v2Cert != nil {
  220. b, err := cs.v2Cert.MarshalJSON()
  221. if err != nil {
  222. return nil, err
  223. }
  224. msg = append(msg, b)
  225. }
  226. return json.Marshal(msg)
  227. }
  228. func newCertStateFromConfig(c *config.C) (*CertState, error) {
  229. var err error
  230. privPathOrPEM := c.GetString("pki.key", "")
  231. if privPathOrPEM == "" {
  232. return nil, errors.New("no pki.key path or PEM data provided")
  233. }
  234. rawKey, curve, isPkcs11, err := loadPrivateKey(privPathOrPEM)
  235. if err != nil {
  236. return nil, err
  237. }
  238. var rawCert []byte
  239. pubPathOrPEM := c.GetString("pki.cert", "")
  240. if pubPathOrPEM == "" {
  241. return nil, errors.New("no pki.cert path or PEM data provided")
  242. }
  243. if strings.Contains(pubPathOrPEM, "-----BEGIN") {
  244. rawCert = []byte(pubPathOrPEM)
  245. pubPathOrPEM = "<inline>"
  246. } else {
  247. rawCert, err = os.ReadFile(pubPathOrPEM)
  248. if err != nil {
  249. return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
  250. }
  251. }
  252. var crt, v1, v2 cert.Certificate
  253. for {
  254. // Load the certificate
  255. crt, rawCert, err = loadCertificate(rawCert)
  256. if err != nil {
  257. return nil, err
  258. }
  259. switch crt.Version() {
  260. case cert.Version1:
  261. if v1 != nil {
  262. return nil, fmt.Errorf("v1 certificate already found in pki.cert")
  263. }
  264. v1 = crt
  265. case cert.Version2:
  266. if v2 != nil {
  267. return nil, fmt.Errorf("v2 certificate already found in pki.cert")
  268. }
  269. v2 = crt
  270. default:
  271. return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
  272. }
  273. if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
  274. break
  275. }
  276. }
  277. if v1 == nil && v2 == nil {
  278. return nil, errors.New("no certificates found in pki.cert")
  279. }
  280. useDefaultVersion := uint32(1)
  281. if v1 == nil {
  282. // The only condition that requires v2 as the default is if only a v2 certificate is present
  283. // We do this to avoid having to configure it specifically in the config file
  284. useDefaultVersion = 2
  285. }
  286. rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
  287. var defaultVersion cert.Version
  288. switch rawDefaultVersion {
  289. case 1:
  290. if v1 == nil {
  291. return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
  292. }
  293. defaultVersion = cert.Version1
  294. case 2:
  295. defaultVersion = cert.Version2
  296. default:
  297. return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
  298. }
  299. return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
  300. }
  301. func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
  302. cs := CertState{
  303. privateKey: privateKey,
  304. pkcs11Backed: pkcs11backed,
  305. myVpnNetworksTable: new(bart.Table[struct{}]),
  306. myVpnAddrsTable: new(bart.Table[struct{}]),
  307. myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
  308. }
  309. if v1 != nil && v2 != nil {
  310. if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
  311. return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
  312. }
  313. if v1.Curve() != v2.Curve() {
  314. return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
  315. }
  316. //TODO: CERT-V2 make sure v2 has v1s address
  317. cs.defaultVersion = dv
  318. }
  319. if v1 != nil {
  320. if pkcs11backed {
  321. //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
  322. } else {
  323. if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
  324. return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
  325. }
  326. }
  327. v1hs, err := v1.MarshalForHandshakes()
  328. if err != nil {
  329. return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
  330. }
  331. cs.v1Cert = v1
  332. cs.v1HandshakeBytes = v1hs
  333. if cs.defaultVersion == 0 {
  334. cs.defaultVersion = cert.Version1
  335. }
  336. }
  337. if v2 != nil {
  338. if pkcs11backed {
  339. //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
  340. } else {
  341. if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
  342. return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
  343. }
  344. }
  345. v2hs, err := v2.MarshalForHandshakes()
  346. if err != nil {
  347. return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
  348. }
  349. cs.v2Cert = v2
  350. cs.v2HandshakeBytes = v2hs
  351. if cs.defaultVersion == 0 {
  352. cs.defaultVersion = cert.Version2
  353. }
  354. }
  355. var crt cert.Certificate
  356. crt = cs.getCertificate(cert.Version2)
  357. if crt == nil {
  358. // v2 certificates are a superset, only look at v1 if its all we have
  359. crt = cs.getCertificate(cert.Version1)
  360. }
  361. for _, network := range crt.Networks() {
  362. cs.myVpnNetworks = append(cs.myVpnNetworks, network)
  363. cs.myVpnNetworksTable.Insert(network, struct{}{})
  364. cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
  365. cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
  366. if network.Addr().Is4() {
  367. addr := network.Masked().Addr().As4()
  368. mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
  369. binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
  370. cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
  371. }
  372. }
  373. return &cs, nil
  374. }
  375. func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
  376. var pemPrivateKey []byte
  377. if strings.Contains(privPathOrPEM, "-----BEGIN") {
  378. pemPrivateKey = []byte(privPathOrPEM)
  379. privPathOrPEM = "<inline>"
  380. rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
  381. if err != nil {
  382. return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
  383. }
  384. } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
  385. rawKey = []byte(privPathOrPEM)
  386. return rawKey, cert.Curve_P256, true, nil
  387. } else {
  388. pemPrivateKey, err = os.ReadFile(privPathOrPEM)
  389. if err != nil {
  390. return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
  391. }
  392. rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
  393. if err != nil {
  394. return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
  395. }
  396. }
  397. return
  398. }
  399. func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
  400. c, b, err := cert.UnmarshalCertificateFromPEM(b)
  401. if err != nil {
  402. return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
  403. }
  404. if c.Expired(time.Now()) {
  405. return nil, b, fmt.Errorf("nebula certificate for this host is expired")
  406. }
  407. if len(c.Networks()) == 0 {
  408. return nil, b, fmt.Errorf("no networks encoded in certificate")
  409. }
  410. if c.IsCA() {
  411. return nil, b, fmt.Errorf("host certificate is a CA certificate")
  412. }
  413. return c, b, nil
  414. }
  415. func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
  416. var rawCA []byte
  417. var err error
  418. caPathOrPEM := c.GetString("pki.ca", "")
  419. if caPathOrPEM == "" {
  420. return nil, errors.New("no pki.ca path or PEM data provided")
  421. }
  422. if strings.Contains(caPathOrPEM, "-----BEGIN") {
  423. rawCA = []byte(caPathOrPEM)
  424. } else {
  425. rawCA, err = os.ReadFile(caPathOrPEM)
  426. if err != nil {
  427. return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
  428. }
  429. }
  430. caPool, err := cert.NewCAPoolFromPEM(rawCA)
  431. if errors.Is(err, cert.ErrExpired) {
  432. var expired int
  433. for _, crt := range caPool.CAs {
  434. if crt.Certificate.Expired(time.Now()) {
  435. expired++
  436. l.WithField("cert", crt).Warn("expired certificate present in CA pool")
  437. }
  438. }
  439. if expired >= len(caPool.CAs) {
  440. return nil, errors.New("no valid CA certificates present")
  441. }
  442. } else if err != nil {
  443. return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
  444. }
  445. for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
  446. l.WithField("fingerprint", fp).Info("Blocklisting cert")
  447. caPool.BlocklistFingerprint(fp)
  448. }
  449. return caPool, nil
  450. }