cert_v1.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. package cert
  2. import (
  3. "bytes"
  4. "crypto/ecdh"
  5. "crypto/ecdsa"
  6. "crypto/ed25519"
  7. "crypto/elliptic"
  8. "crypto/sha256"
  9. "encoding/binary"
  10. "encoding/hex"
  11. "encoding/json"
  12. "encoding/pem"
  13. "fmt"
  14. "net"
  15. "net/netip"
  16. "time"
  17. "golang.org/x/crypto/curve25519"
  18. "google.golang.org/protobuf/proto"
  19. )
  // publicKeyLen is 32. It is not referenced by the code in this file;
  // presumably it is the byte length of a Curve25519/Ed25519 public key
  // (NOTE(review): confirm against its users elsewhere in the package).
  const (
  	publicKeyLen = 32
  )
  21. type certificateV1 struct {
  22. details detailsV1
  23. signature []byte
  24. }
  25. type detailsV1 struct {
  26. name string
  27. networks []netip.Prefix
  28. unsafeNetworks []netip.Prefix
  29. groups []string
  30. notBefore time.Time
  31. notAfter time.Time
  32. publicKey []byte
  33. isCA bool
  34. issuer string
  35. curve Curve
  36. }
  // m is shorthand for a JSON object while building the JSON view of a
  // certificate. It is an alias, so an m is a plain map[string]any.
  type (
  	m = map[string]any
  )
  38. func (c *certificateV1) Version() Version {
  39. return Version1
  40. }
  41. func (c *certificateV1) Curve() Curve {
  42. return c.details.curve
  43. }
  44. func (c *certificateV1) Groups() []string {
  45. return c.details.groups
  46. }
  47. func (c *certificateV1) IsCA() bool {
  48. return c.details.isCA
  49. }
  50. func (c *certificateV1) Issuer() string {
  51. return c.details.issuer
  52. }
  53. func (c *certificateV1) Name() string {
  54. return c.details.name
  55. }
  56. func (c *certificateV1) Networks() []netip.Prefix {
  57. return c.details.networks
  58. }
  59. func (c *certificateV1) NotAfter() time.Time {
  60. return c.details.notAfter
  61. }
  62. func (c *certificateV1) NotBefore() time.Time {
  63. return c.details.notBefore
  64. }
  65. func (c *certificateV1) PublicKey() []byte {
  66. return c.details.publicKey
  67. }
  68. func (c *certificateV1) MarshalPublicKeyPEM() []byte {
  69. return marshalCertPublicKeyToPEM(c)
  70. }
  71. func (c *certificateV1) Signature() []byte {
  72. return c.signature
  73. }
  74. func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
  75. return c.details.unsafeNetworks
  76. }
  77. func (c *certificateV1) Fingerprint() (string, error) {
  78. b, err := c.Marshal()
  79. if err != nil {
  80. return "", err
  81. }
  82. sum := sha256.Sum256(b)
  83. return hex.EncodeToString(sum[:]), nil
  84. }
  85. func (c *certificateV1) CheckSignature(key []byte) bool {
  86. b, err := proto.Marshal(c.getRawDetails())
  87. if err != nil {
  88. return false
  89. }
  90. switch c.details.curve {
  91. case Curve_CURVE25519:
  92. return ed25519.Verify(key, b, c.signature)
  93. case Curve_P256:
  94. pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
  95. if err != nil {
  96. return false
  97. }
  98. hashed := sha256.Sum256(b)
  99. return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
  100. default:
  101. return false
  102. }
  103. }
  104. func (c *certificateV1) Expired(t time.Time) bool {
  105. return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
  106. }
  107. func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
  108. if curve != c.details.curve {
  109. return fmt.Errorf("curve in cert and private key supplied don't match")
  110. }
  111. if c.details.isCA {
  112. switch curve {
  113. case Curve_CURVE25519:
  114. // the call to PublicKey below will panic slice bounds out of range otherwise
  115. if len(key) != ed25519.PrivateKeySize {
  116. return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
  117. }
  118. if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
  119. return fmt.Errorf("public key in cert and private key supplied don't match")
  120. }
  121. case Curve_P256:
  122. privkey, err := ecdh.P256().NewPrivateKey(key)
  123. if err != nil {
  124. return fmt.Errorf("cannot parse private key as P256: %w", err)
  125. }
  126. pub := privkey.PublicKey().Bytes()
  127. if !bytes.Equal(pub, c.details.publicKey) {
  128. return fmt.Errorf("public key in cert and private key supplied don't match")
  129. }
  130. default:
  131. return fmt.Errorf("invalid curve: %s", curve)
  132. }
  133. return nil
  134. }
  135. var pub []byte
  136. switch curve {
  137. case Curve_CURVE25519:
  138. var err error
  139. pub, err = curve25519.X25519(key, curve25519.Basepoint)
  140. if err != nil {
  141. return err
  142. }
  143. case Curve_P256:
  144. privkey, err := ecdh.P256().NewPrivateKey(key)
  145. if err != nil {
  146. return err
  147. }
  148. pub = privkey.PublicKey().Bytes()
  149. default:
  150. return fmt.Errorf("invalid curve: %s", curve)
  151. }
  152. if !bytes.Equal(pub, c.details.publicKey) {
  153. return fmt.Errorf("public key in cert and private key supplied don't match")
  154. }
  155. return nil
  156. }
  157. // getRawDetails marshals the raw details into protobuf ready struct
  158. func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
  159. rd := &RawNebulaCertificateDetails{
  160. Name: c.details.name,
  161. Groups: c.details.groups,
  162. NotBefore: c.details.notBefore.Unix(),
  163. NotAfter: c.details.notAfter.Unix(),
  164. PublicKey: make([]byte, len(c.details.publicKey)),
  165. IsCA: c.details.isCA,
  166. Curve: c.details.curve,
  167. }
  168. for _, ipNet := range c.details.networks {
  169. mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
  170. rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
  171. }
  172. for _, ipNet := range c.details.unsafeNetworks {
  173. mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
  174. rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
  175. }
  176. copy(rd.PublicKey, c.details.publicKey[:])
  177. // I know, this is terrible
  178. rd.Issuer, _ = hex.DecodeString(c.details.issuer)
  179. return rd
  180. }
  181. func (c *certificateV1) String() string {
  182. b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
  183. if err != nil {
  184. return fmt.Sprintf("<error marshalling certificate: %v>", err)
  185. }
  186. return string(b)
  187. }
  188. func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
  189. pubKey := c.details.publicKey
  190. c.details.publicKey = nil
  191. rawCertNoKey, err := c.Marshal()
  192. if err != nil {
  193. return nil, err
  194. }
  195. c.details.publicKey = pubKey
  196. return rawCertNoKey, nil
  197. }
  198. func (c *certificateV1) Marshal() ([]byte, error) {
  199. rc := RawNebulaCertificate{
  200. Details: c.getRawDetails(),
  201. Signature: c.signature,
  202. }
  203. return proto.Marshal(&rc)
  204. }
  205. func (c *certificateV1) MarshalPEM() ([]byte, error) {
  206. b, err := c.Marshal()
  207. if err != nil {
  208. return nil, err
  209. }
  210. return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
  211. }
  212. func (c *certificateV1) MarshalJSON() ([]byte, error) {
  213. return json.Marshal(c.marshalJSON())
  214. }
  215. func (c *certificateV1) marshalJSON() m {
  216. fp, _ := c.Fingerprint()
  217. return m{
  218. "version": Version1,
  219. "details": m{
  220. "name": c.details.name,
  221. "networks": c.details.networks,
  222. "unsafeNetworks": c.details.unsafeNetworks,
  223. "groups": c.details.groups,
  224. "notBefore": c.details.notBefore,
  225. "notAfter": c.details.notAfter,
  226. "publicKey": fmt.Sprintf("%x", c.details.publicKey),
  227. "isCa": c.details.isCA,
  228. "issuer": c.details.issuer,
  229. "curve": c.details.curve.String(),
  230. },
  231. "fingerprint": fp,
  232. "signature": fmt.Sprintf("%x", c.Signature()),
  233. }
  234. }
  235. func (c *certificateV1) Copy() Certificate {
  236. nc := &certificateV1{
  237. details: detailsV1{
  238. name: c.details.name,
  239. notBefore: c.details.notBefore,
  240. notAfter: c.details.notAfter,
  241. publicKey: make([]byte, len(c.details.publicKey)),
  242. isCA: c.details.isCA,
  243. issuer: c.details.issuer,
  244. curve: c.details.curve,
  245. },
  246. signature: make([]byte, len(c.signature)),
  247. }
  248. if c.details.groups != nil {
  249. nc.details.groups = make([]string, len(c.details.groups))
  250. copy(nc.details.groups, c.details.groups)
  251. }
  252. if c.details.networks != nil {
  253. nc.details.networks = make([]netip.Prefix, len(c.details.networks))
  254. copy(nc.details.networks, c.details.networks)
  255. }
  256. if c.details.unsafeNetworks != nil {
  257. nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
  258. copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
  259. }
  260. copy(nc.signature, c.signature)
  261. copy(nc.details.publicKey, c.details.publicKey)
  262. return nc
  263. }
  264. func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
  265. c.details = detailsV1{
  266. name: t.Name,
  267. networks: t.Networks,
  268. unsafeNetworks: t.UnsafeNetworks,
  269. groups: t.Groups,
  270. notBefore: t.NotBefore,
  271. notAfter: t.NotAfter,
  272. publicKey: t.PublicKey,
  273. isCA: t.IsCA,
  274. curve: t.Curve,
  275. issuer: t.issuer,
  276. }
  277. return c.validate()
  278. }
  279. func (c *certificateV1) validate() error {
  280. // Empty names are allowed
  281. if len(c.details.publicKey) == 0 {
  282. return ErrInvalidPublicKey
  283. }
  284. // Original v1 rules allowed multiple networks to be present but ignored all but the first one.
  285. // Continue to allow this behavior
  286. if !c.details.isCA && len(c.details.networks) == 0 {
  287. return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
  288. }
  289. for _, network := range c.details.networks {
  290. if !network.IsValid() || !network.Addr().IsValid() {
  291. return NewErrInvalidCertificateProperties("invalid network: %s", network)
  292. }
  293. if network.Addr().Is6() {
  294. return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
  295. }
  296. if network.Addr().IsUnspecified() {
  297. return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
  298. }
  299. if network.Addr().Zone() != "" {
  300. return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
  301. }
  302. }
  303. for _, network := range c.details.unsafeNetworks {
  304. if !network.IsValid() || !network.Addr().IsValid() {
  305. return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
  306. }
  307. if network.Addr().Is6() {
  308. return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
  309. }
  310. if network.Addr().Zone() != "" {
  311. return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
  312. }
  313. }
  314. // v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
  315. // We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
  316. // unsafe networks would result in a different signature.
  317. return nil
  318. }
  319. func (c *certificateV1) marshalForSigning() ([]byte, error) {
  320. b, err := proto.Marshal(c.getRawDetails())
  321. if err != nil {
  322. return nil, err
  323. }
  324. return b, nil
  325. }
  326. func (c *certificateV1) setSignature(b []byte) error {
  327. if len(b) == 0 {
  328. return ErrEmptySignature
  329. }
  330. c.signature = b
  331. return nil
  332. }
  333. // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
  334. // if the publicKey is provided here then it is not required to be present in `b`
  335. func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
  336. if len(b) == 0 {
  337. return nil, fmt.Errorf("nil byte array")
  338. }
  339. var rc RawNebulaCertificate
  340. err := proto.Unmarshal(b, &rc)
  341. if err != nil {
  342. return nil, err
  343. }
  344. if rc.Details == nil {
  345. return nil, fmt.Errorf("encoded Details was nil")
  346. }
  347. if len(rc.Details.Ips)%2 != 0 {
  348. return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
  349. }
  350. if len(rc.Details.Subnets)%2 != 0 {
  351. return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
  352. }
  353. nc := certificateV1{
  354. details: detailsV1{
  355. name: rc.Details.Name,
  356. groups: make([]string, len(rc.Details.Groups)),
  357. networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
  358. unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
  359. notBefore: time.Unix(rc.Details.NotBefore, 0),
  360. notAfter: time.Unix(rc.Details.NotAfter, 0),
  361. publicKey: make([]byte, len(rc.Details.PublicKey)),
  362. isCA: rc.Details.IsCA,
  363. curve: rc.Details.Curve,
  364. },
  365. signature: make([]byte, len(rc.Signature)),
  366. }
  367. copy(nc.signature, rc.Signature)
  368. copy(nc.details.groups, rc.Details.Groups)
  369. nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
  370. if len(publicKey) > 0 {
  371. nc.details.publicKey = publicKey
  372. }
  373. copy(nc.details.publicKey, rc.Details.PublicKey)
  374. var ip netip.Addr
  375. for i, rawIp := range rc.Details.Ips {
  376. if i%2 == 0 {
  377. ip = int2addr(rawIp)
  378. } else {
  379. ones, _ := net.IPMask(int2ip(rawIp)).Size()
  380. nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
  381. }
  382. }
  383. for i, rawIp := range rc.Details.Subnets {
  384. if i%2 == 0 {
  385. ip = int2addr(rawIp)
  386. } else {
  387. ones, _ := net.IPMask(int2ip(rawIp)).Size()
  388. nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
  389. }
  390. }
  391. err = nc.validate()
  392. if err != nil {
  393. return nil, err
  394. }
  395. return &nc, nil
  396. }
  // ip2int packs an IPv4 address or netmask into a big-endian uint32. A
  // 16-byte input contributes only its last four bytes (the IPv4 part of an
  // IPv4-mapped value). Any other input uses its first four bytes and panics
  // if shorter than four.
  func ip2int(ip []byte) uint32 {
  	if len(ip) == net.IPv6len {
  		ip = ip[12:]
  	}
  	return binary.BigEndian.Uint32(ip)
  }
  // int2ip unpacks a big-endian uint32 into a 4-byte net.IP (the inverse of
  // ip2int for 4-byte inputs).
  func int2ip(nn uint32) net.IP {
  	return net.IP{byte(nn >> 24), byte(nn >> 16), byte(nn >> 8), byte(nn)}
  }
  // addr2int packs an IPv4 (or IPv4-mapped IPv6) address into a big-endian
  // uint32. Like netip.Addr.As4, it panics for other IPv6 addresses and for
  // the zero Addr.
  func addr2int(addr netip.Addr) uint32 {
  	a := addr.Unmap().As4()
  	return uint32(a[0])<<24 | uint32(a[1])<<16 | uint32(a[2])<<8 | uint32(a[3])
  }
  // int2addr unpacks a big-endian uint32 into an IPv4 netip.Addr (the inverse
  // of addr2int). AddrFrom4 always yields a plain IPv4 address, never an
  // IPv4-mapped IPv6 one.
  func int2addr(nn uint32) netip.Addr {
  	return netip.AddrFrom4([4]byte{byte(nn >> 24), byte(nn >> 16), byte(nn >> 8), byte(nn)})
  }