pki.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. package nebula
  2. import (
  3. "errors"
  4. "fmt"
  5. "os"
  6. "strings"
  7. "sync/atomic"
  8. "time"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/config"
  12. "github.com/slackhq/nebula/util"
  13. )
  14. type PKI struct {
  15. cs atomic.Pointer[CertState]
  16. caPool atomic.Pointer[cert.NebulaCAPool]
  17. l *logrus.Logger
  18. }
  19. type CertState struct {
  20. Certificate *cert.NebulaCertificate
  21. RawCertificate []byte
  22. RawCertificateNoKey []byte
  23. PublicKey []byte
  24. PrivateKey []byte
  25. }
  26. func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
  27. pki := &PKI{l: l}
  28. err := pki.reload(c, true)
  29. if err != nil {
  30. return nil, err
  31. }
  32. c.RegisterReloadCallback(func(c *config.C) {
  33. rErr := pki.reload(c, false)
  34. if rErr != nil {
  35. util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
  36. }
  37. })
  38. return pki, nil
  39. }
  40. func (p *PKI) GetCertState() *CertState {
  41. return p.cs.Load()
  42. }
  43. func (p *PKI) GetCAPool() *cert.NebulaCAPool {
  44. return p.caPool.Load()
  45. }
  46. func (p *PKI) reload(c *config.C, initial bool) error {
  47. err := p.reloadCert(c, initial)
  48. if err != nil {
  49. if initial {
  50. return err
  51. }
  52. err.Log(p.l)
  53. }
  54. err = p.reloadCAPool(c)
  55. if err != nil {
  56. if initial {
  57. return err
  58. }
  59. err.Log(p.l)
  60. }
  61. return nil
  62. }
  63. func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
  64. cs, err := newCertStateFromConfig(c)
  65. if err != nil {
  66. return util.NewContextualError("Could not load client cert", nil, err)
  67. }
  68. if !initial {
  69. //TODO: include check for mask equality as well
  70. // did IP in cert change? if so, don't set
  71. currentCert := p.cs.Load().Certificate
  72. oldIPs := currentCert.Details.Ips
  73. newIPs := cs.Certificate.Details.Ips
  74. if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
  75. return util.NewContextualError(
  76. "IP in new cert was different from old",
  77. m{"new_ip": newIPs[0], "old_ip": oldIPs[0]},
  78. nil,
  79. )
  80. }
  81. }
  82. p.cs.Store(cs)
  83. if initial {
  84. p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
  85. } else {
  86. p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
  87. }
  88. return nil
  89. }
  90. func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
  91. caPool, err := loadCAPoolFromConfig(p.l, c)
  92. if err != nil {
  93. return util.NewContextualError("Failed to load ca from config", nil, err)
  94. }
  95. p.caPool.Store(caPool)
  96. p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
  97. return nil
  98. }
  99. func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
  100. // Marshal the certificate to ensure it is valid
  101. rawCertificate, err := certificate.Marshal()
  102. if err != nil {
  103. return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
  104. }
  105. publicKey := certificate.Details.PublicKey
  106. cs := &CertState{
  107. RawCertificate: rawCertificate,
  108. Certificate: certificate,
  109. PrivateKey: privateKey,
  110. PublicKey: publicKey,
  111. }
  112. cs.Certificate.Details.PublicKey = nil
  113. rawCertNoKey, err := cs.Certificate.Marshal()
  114. if err != nil {
  115. return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
  116. }
  117. cs.RawCertificateNoKey = rawCertNoKey
  118. // put public key back
  119. cs.Certificate.Details.PublicKey = cs.PublicKey
  120. return cs, nil
  121. }
  122. func newCertStateFromConfig(c *config.C) (*CertState, error) {
  123. var pemPrivateKey []byte
  124. var err error
  125. privPathOrPEM := c.GetString("pki.key", "")
  126. if privPathOrPEM == "" {
  127. return nil, errors.New("no pki.key path or PEM data provided")
  128. }
  129. if strings.Contains(privPathOrPEM, "-----BEGIN") {
  130. pemPrivateKey = []byte(privPathOrPEM)
  131. privPathOrPEM = "<inline>"
  132. } else {
  133. pemPrivateKey, err = os.ReadFile(privPathOrPEM)
  134. if err != nil {
  135. return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
  136. }
  137. }
  138. rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
  139. if err != nil {
  140. return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
  141. }
  142. var rawCert []byte
  143. pubPathOrPEM := c.GetString("pki.cert", "")
  144. if pubPathOrPEM == "" {
  145. return nil, errors.New("no pki.cert path or PEM data provided")
  146. }
  147. if strings.Contains(pubPathOrPEM, "-----BEGIN") {
  148. rawCert = []byte(pubPathOrPEM)
  149. pubPathOrPEM = "<inline>"
  150. } else {
  151. rawCert, err = os.ReadFile(pubPathOrPEM)
  152. if err != nil {
  153. return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
  154. }
  155. }
  156. nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
  157. if err != nil {
  158. return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
  159. }
  160. if nebulaCert.Expired(time.Now()) {
  161. return nil, fmt.Errorf("nebula certificate for this host is expired")
  162. }
  163. if len(nebulaCert.Details.Ips) == 0 {
  164. return nil, fmt.Errorf("no IPs encoded in certificate")
  165. }
  166. if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
  167. return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
  168. }
  169. return newCertState(nebulaCert, rawKey)
  170. }
  171. func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
  172. var rawCA []byte
  173. var err error
  174. caPathOrPEM := c.GetString("pki.ca", "")
  175. if caPathOrPEM == "" {
  176. return nil, errors.New("no pki.ca path or PEM data provided")
  177. }
  178. if strings.Contains(caPathOrPEM, "-----BEGIN") {
  179. rawCA = []byte(caPathOrPEM)
  180. } else {
  181. rawCA, err = os.ReadFile(caPathOrPEM)
  182. if err != nil {
  183. return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
  184. }
  185. }
  186. caPool, err := cert.NewCAPoolFromBytes(rawCA)
  187. if errors.Is(err, cert.ErrExpired) {
  188. var expired int
  189. for _, crt := range caPool.CAs {
  190. if crt.Expired(time.Now()) {
  191. expired++
  192. l.WithField("cert", crt).Warn("expired certificate present in CA pool")
  193. }
  194. }
  195. if expired >= len(caPool.CAs) {
  196. return nil, errors.New("no valid CA certificates present")
  197. }
  198. } else if err != nil {
  199. return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
  200. }
  201. for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
  202. l.WithField("fingerprint", fp).Info("Blocklisting cert")
  203. caPool.BlocklistFingerprint(fp)
  204. }
  205. return caPool, nil
  206. }