ca_pool.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package cert
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/netip"
  6. "slices"
  7. "strings"
  8. "time"
  9. )
  10. type CAPool struct {
  11. CAs map[string]*CachedCertificate
  12. certBlocklist map[string]struct{}
  13. }
  14. // NewCAPool creates an empty CAPool
  15. func NewCAPool() *CAPool {
  16. ca := CAPool{
  17. CAs: make(map[string]*CachedCertificate),
  18. certBlocklist: make(map[string]struct{}),
  19. }
  20. return &ca
  21. }
  22. // NewCAPoolFromPEM will create a new CA pool from the provided
  23. // input bytes, which must be a PEM-encoded set of nebula certificates.
  24. // If the pool contains any expired certificates, an ErrExpired will be
  25. // returned along with the pool. The caller must handle any such errors.
  26. func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
  27. pool := NewCAPool()
  28. var err error
  29. var expired bool
  30. for {
  31. caPEMs, err = pool.AddCAFromPEM(caPEMs)
  32. if errors.Is(err, ErrExpired) {
  33. expired = true
  34. err = nil
  35. }
  36. if err != nil {
  37. return nil, err
  38. }
  39. if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
  40. break
  41. }
  42. }
  43. if expired {
  44. return pool, ErrExpired
  45. }
  46. return pool, nil
  47. }
  48. // AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool.
  49. // Only the first pem encoded object will be consumed, any remaining bytes are returned.
  50. // Parsed certificates will be verified and must be a CA
  51. func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
  52. c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes)
  53. if err != nil {
  54. return pemBytes, err
  55. }
  56. err = ncp.AddCA(c)
  57. if err != nil {
  58. return pemBytes, err
  59. }
  60. return pemBytes, nil
  61. }
  62. // AddCA verifies a Nebula CA certificate and adds it to the pool.
  63. func (ncp *CAPool) AddCA(c Certificate) error {
  64. if !c.IsCA() {
  65. return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
  66. }
  67. if !c.CheckSignature(c.PublicKey()) {
  68. return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
  69. }
  70. sum, err := c.Fingerprint()
  71. if err != nil {
  72. return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
  73. }
  74. cc := &CachedCertificate{
  75. Certificate: c,
  76. Fingerprint: sum,
  77. InvertedGroups: make(map[string]struct{}),
  78. }
  79. for _, g := range c.Groups() {
  80. cc.InvertedGroups[g] = struct{}{}
  81. }
  82. ncp.CAs[sum] = cc
  83. if c.Expired(time.Now()) {
  84. return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
  85. }
  86. return nil
  87. }
  88. // BlocklistFingerprint adds a cert fingerprint to the blocklist
  89. func (ncp *CAPool) BlocklistFingerprint(f string) {
  90. ncp.certBlocklist[f] = struct{}{}
  91. }
  92. // ResetCertBlocklist removes all previously blocklisted cert fingerprints
  93. func (ncp *CAPool) ResetCertBlocklist() {
  94. ncp.certBlocklist = make(map[string]struct{})
  95. }
  96. // IsBlocklisted tests the provided fingerprint against the pools blocklist.
  97. // Returns true if the fingerprint is blocked.
  98. func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
  99. if _, ok := ncp.certBlocklist[fingerprint]; ok {
  100. return true
  101. }
  102. return false
  103. }
  104. // VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool.
  105. // If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts
  106. // to increase performance.
  107. func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
  108. if c == nil {
  109. return nil, fmt.Errorf("no certificate")
  110. }
  111. fp, err := c.Fingerprint()
  112. if err != nil {
  113. return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
  114. }
  115. signer, err := ncp.verify(c, now, fp, "")
  116. if err != nil {
  117. return nil, err
  118. }
  119. cc := CachedCertificate{
  120. Certificate: c,
  121. InvertedGroups: make(map[string]struct{}),
  122. Fingerprint: fp,
  123. signerFingerprint: signer.Fingerprint,
  124. }
  125. for _, g := range c.Groups() {
  126. cc.InvertedGroups[g] = struct{}{}
  127. }
  128. return &cc, nil
  129. }
  130. // VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
  131. // is a cheaper operation to perform as a result.
  132. func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
  133. _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
  134. return err
  135. }
  136. func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
  137. if ncp.IsBlocklisted(certFp) {
  138. return nil, ErrBlockListed
  139. }
  140. signer, err := ncp.GetCAForCert(c)
  141. if err != nil {
  142. return nil, err
  143. }
  144. if signer.Certificate.Expired(now) {
  145. return nil, ErrRootExpired
  146. }
  147. if c.Expired(now) {
  148. return nil, ErrExpired
  149. }
  150. // If we are checking a cached certificate then we can bail early here
  151. // Either the root is no longer trusted or everything is fine
  152. if len(signerFp) > 0 {
  153. if signerFp != signer.Fingerprint {
  154. return nil, ErrFingerprintMismatch
  155. }
  156. return signer, nil
  157. }
  158. if !c.CheckSignature(signer.Certificate.PublicKey()) {
  159. return nil, ErrSignatureMismatch
  160. }
  161. err = CheckCAConstraints(signer.Certificate, c)
  162. if err != nil {
  163. return nil, err
  164. }
  165. return signer, nil
  166. }
  167. // GetCAForCert attempts to return the signing certificate for the provided certificate.
  168. // No signature validation is performed
  169. func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
  170. issuer := c.Issuer()
  171. if issuer == "" {
  172. return nil, fmt.Errorf("no issuer in certificate")
  173. }
  174. signer, ok := ncp.CAs[issuer]
  175. if ok {
  176. return signer, nil
  177. }
  178. return nil, ErrCaNotFound
  179. }
  180. // GetFingerprints returns an array of trusted CA fingerprints
  181. func (ncp *CAPool) GetFingerprints() []string {
  182. fp := make([]string, len(ncp.CAs))
  183. i := 0
  184. for k := range ncp.CAs {
  185. fp[i] = k
  186. i++
  187. }
  188. return fp
  189. }
  190. // CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate.
  191. func CheckCAConstraints(signer Certificate, sub Certificate) error {
  192. return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks())
  193. }
  194. // checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested.
  195. func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
  196. // Make sure this cert isn't valid after the root
  197. if notAfter.After(signer.NotAfter()) {
  198. return fmt.Errorf("certificate expires after signing certificate")
  199. }
  200. // Make sure this cert wasn't valid before the root
  201. if notBefore.Before(signer.NotBefore()) {
  202. return fmt.Errorf("certificate is valid before the signing certificate")
  203. }
  204. // If the signer has a limited set of groups make sure the cert only contains a subset
  205. signerGroups := signer.Groups()
  206. if len(signerGroups) > 0 {
  207. for _, g := range groups {
  208. if !slices.Contains(signerGroups, g) {
  209. return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
  210. }
  211. }
  212. }
  213. // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
  214. signingNetworks := signer.Networks()
  215. if len(signingNetworks) > 0 {
  216. for _, certNetwork := range networks {
  217. found := false
  218. for _, signingNetwork := range signingNetworks {
  219. if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() {
  220. found = true
  221. break
  222. }
  223. }
  224. if !found {
  225. return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String())
  226. }
  227. }
  228. }
  229. // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
  230. signingUnsafeNetworks := signer.UnsafeNetworks()
  231. if len(signingUnsafeNetworks) > 0 {
  232. for _, certUnsafeNetwork := range unsafeNetworks {
  233. found := false
  234. for _, caNetwork := range signingUnsafeNetworks {
  235. if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() {
  236. found = true
  237. break
  238. }
  239. }
  240. if !found {
  241. return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String())
  242. }
  243. }
  244. }
  245. return nil
  246. }