cert_v1.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. package cert
  2. import (
  3. "bytes"
  4. "crypto/ecdh"
  5. "crypto/ecdsa"
  6. "crypto/ed25519"
  7. "crypto/elliptic"
  8. "crypto/sha256"
  9. "encoding/binary"
  10. "encoding/hex"
  11. "encoding/json"
  12. "encoding/pem"
  13. "fmt"
  14. "net"
  15. "net/netip"
  16. "time"
  17. "golang.org/x/crypto/curve25519"
  18. "google.golang.org/protobuf/proto"
  19. )
// publicKeyLen is a public key length of 32 bytes. It is not referenced in this
// part of the file. NOTE(review): presumably the expected size of a Curve25519 /
// ed25519 public key used elsewhere in the package — confirm against its users.
const publicKeyLen = 32
// certificateV1 is the in-memory representation of a version 1 nebula
// certificate: the certificate details together with the signature over them.
type certificateV1 struct {
	details detailsV1
	// signature is verified by CheckSignature against the protobuf encoding of
	// details produced by getRawDetails.
	signature []byte
}
  25. type detailsV1 struct {
  26. name string
  27. networks []netip.Prefix
  28. unsafeNetworks []netip.Prefix
  29. groups []string
  30. notBefore time.Time
  31. notAfter time.Time
  32. publicKey []byte
  33. isCA bool
  34. issuer string
  35. curve Curve
  36. }
// m is shorthand for the generic JSON object type built by marshalJSON.
type m = map[string]any
  38. func (c *certificateV1) Version() Version {
  39. return Version1
  40. }
  41. func (c *certificateV1) Curve() Curve {
  42. return c.details.curve
  43. }
  44. func (c *certificateV1) Groups() []string {
  45. return c.details.groups
  46. }
  47. func (c *certificateV1) IsCA() bool {
  48. return c.details.isCA
  49. }
  50. func (c *certificateV1) Issuer() string {
  51. return c.details.issuer
  52. }
  53. func (c *certificateV1) Name() string {
  54. return c.details.name
  55. }
  56. func (c *certificateV1) Networks() []netip.Prefix {
  57. return c.details.networks
  58. }
  59. func (c *certificateV1) NotAfter() time.Time {
  60. return c.details.notAfter
  61. }
  62. func (c *certificateV1) NotBefore() time.Time {
  63. return c.details.notBefore
  64. }
  65. func (c *certificateV1) PublicKey() []byte {
  66. return c.details.publicKey
  67. }
  68. func (c *certificateV1) Signature() []byte {
  69. return c.signature
  70. }
  71. func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
  72. return c.details.unsafeNetworks
  73. }
  74. func (c *certificateV1) Fingerprint() (string, error) {
  75. b, err := c.Marshal()
  76. if err != nil {
  77. return "", err
  78. }
  79. sum := sha256.Sum256(b)
  80. return hex.EncodeToString(sum[:]), nil
  81. }
  82. func (c *certificateV1) CheckSignature(key []byte) bool {
  83. b, err := proto.Marshal(c.getRawDetails())
  84. if err != nil {
  85. return false
  86. }
  87. switch c.details.curve {
  88. case Curve_CURVE25519:
  89. return ed25519.Verify(key, b, c.signature)
  90. case Curve_P256:
  91. x, y := elliptic.Unmarshal(elliptic.P256(), key)
  92. pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
  93. hashed := sha256.Sum256(b)
  94. return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
  95. default:
  96. return false
  97. }
  98. }
  99. func (c *certificateV1) Expired(t time.Time) bool {
  100. return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
  101. }
  102. func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
  103. if curve != c.details.curve {
  104. return fmt.Errorf("curve in cert and private key supplied don't match")
  105. }
  106. if c.details.isCA {
  107. switch curve {
  108. case Curve_CURVE25519:
  109. // the call to PublicKey below will panic slice bounds out of range otherwise
  110. if len(key) != ed25519.PrivateKeySize {
  111. return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
  112. }
  113. if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
  114. return fmt.Errorf("public key in cert and private key supplied don't match")
  115. }
  116. case Curve_P256:
  117. privkey, err := ecdh.P256().NewPrivateKey(key)
  118. if err != nil {
  119. return fmt.Errorf("cannot parse private key as P256: %w", err)
  120. }
  121. pub := privkey.PublicKey().Bytes()
  122. if !bytes.Equal(pub, c.details.publicKey) {
  123. return fmt.Errorf("public key in cert and private key supplied don't match")
  124. }
  125. default:
  126. return fmt.Errorf("invalid curve: %s", curve)
  127. }
  128. return nil
  129. }
  130. var pub []byte
  131. switch curve {
  132. case Curve_CURVE25519:
  133. var err error
  134. pub, err = curve25519.X25519(key, curve25519.Basepoint)
  135. if err != nil {
  136. return err
  137. }
  138. case Curve_P256:
  139. privkey, err := ecdh.P256().NewPrivateKey(key)
  140. if err != nil {
  141. return err
  142. }
  143. pub = privkey.PublicKey().Bytes()
  144. default:
  145. return fmt.Errorf("invalid curve: %s", curve)
  146. }
  147. if !bytes.Equal(pub, c.details.publicKey) {
  148. return fmt.Errorf("public key in cert and private key supplied don't match")
  149. }
  150. return nil
  151. }
  152. // getRawDetails marshals the raw details into protobuf ready struct
  153. func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
  154. rd := &RawNebulaCertificateDetails{
  155. Name: c.details.name,
  156. Groups: c.details.groups,
  157. NotBefore: c.details.notBefore.Unix(),
  158. NotAfter: c.details.notAfter.Unix(),
  159. PublicKey: make([]byte, len(c.details.publicKey)),
  160. IsCA: c.details.isCA,
  161. Curve: c.details.curve,
  162. }
  163. for _, ipNet := range c.details.networks {
  164. mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
  165. rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
  166. }
  167. for _, ipNet := range c.details.unsafeNetworks {
  168. mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
  169. rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
  170. }
  171. copy(rd.PublicKey, c.details.publicKey[:])
  172. // I know, this is terrible
  173. rd.Issuer, _ = hex.DecodeString(c.details.issuer)
  174. return rd
  175. }
  176. func (c *certificateV1) String() string {
  177. b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
  178. if err != nil {
  179. return fmt.Sprintf("<error marshalling certificate: %v>", err)
  180. }
  181. return string(b)
  182. }
  183. func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
  184. pubKey := c.details.publicKey
  185. c.details.publicKey = nil
  186. rawCertNoKey, err := c.Marshal()
  187. if err != nil {
  188. return nil, err
  189. }
  190. c.details.publicKey = pubKey
  191. return rawCertNoKey, nil
  192. }
  193. func (c *certificateV1) Marshal() ([]byte, error) {
  194. rc := RawNebulaCertificate{
  195. Details: c.getRawDetails(),
  196. Signature: c.signature,
  197. }
  198. return proto.Marshal(&rc)
  199. }
  200. func (c *certificateV1) MarshalPEM() ([]byte, error) {
  201. b, err := c.Marshal()
  202. if err != nil {
  203. return nil, err
  204. }
  205. return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
  206. }
  207. func (c *certificateV1) MarshalJSON() ([]byte, error) {
  208. return json.Marshal(c.marshalJSON())
  209. }
  210. func (c *certificateV1) marshalJSON() m {
  211. fp, _ := c.Fingerprint()
  212. return m{
  213. "version": Version1,
  214. "details": m{
  215. "name": c.details.name,
  216. "networks": c.details.networks,
  217. "unsafeNetworks": c.details.unsafeNetworks,
  218. "groups": c.details.groups,
  219. "notBefore": c.details.notBefore,
  220. "notAfter": c.details.notAfter,
  221. "publicKey": fmt.Sprintf("%x", c.details.publicKey),
  222. "isCa": c.details.isCA,
  223. "issuer": c.details.issuer,
  224. "curve": c.details.curve.String(),
  225. },
  226. "fingerprint": fp,
  227. "signature": fmt.Sprintf("%x", c.Signature()),
  228. }
  229. }
  230. func (c *certificateV1) Copy() Certificate {
  231. nc := &certificateV1{
  232. details: detailsV1{
  233. name: c.details.name,
  234. notBefore: c.details.notBefore,
  235. notAfter: c.details.notAfter,
  236. publicKey: make([]byte, len(c.details.publicKey)),
  237. isCA: c.details.isCA,
  238. issuer: c.details.issuer,
  239. curve: c.details.curve,
  240. },
  241. signature: make([]byte, len(c.signature)),
  242. }
  243. if c.details.groups != nil {
  244. nc.details.groups = make([]string, len(c.details.groups))
  245. copy(nc.details.groups, c.details.groups)
  246. }
  247. if c.details.networks != nil {
  248. nc.details.networks = make([]netip.Prefix, len(c.details.networks))
  249. copy(nc.details.networks, c.details.networks)
  250. }
  251. if c.details.unsafeNetworks != nil {
  252. nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
  253. copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
  254. }
  255. copy(nc.signature, c.signature)
  256. copy(nc.details.publicKey, c.details.publicKey)
  257. return nc
  258. }
  259. func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
  260. c.details = detailsV1{
  261. name: t.Name,
  262. networks: t.Networks,
  263. unsafeNetworks: t.UnsafeNetworks,
  264. groups: t.Groups,
  265. notBefore: t.NotBefore,
  266. notAfter: t.NotAfter,
  267. publicKey: t.PublicKey,
  268. isCA: t.IsCA,
  269. curve: t.Curve,
  270. issuer: t.issuer,
  271. }
  272. return c.validate()
  273. }
  274. func (c *certificateV1) validate() error {
  275. // Empty names are allowed
  276. if len(c.details.publicKey) == 0 {
  277. return ErrInvalidPublicKey
  278. }
  279. // Original v1 rules allowed multiple networks to be present but ignored all but the first one.
  280. // Continue to allow this behavior
  281. if !c.details.isCA && len(c.details.networks) == 0 {
  282. return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
  283. }
  284. for _, network := range c.details.networks {
  285. if !network.IsValid() || !network.Addr().IsValid() {
  286. return NewErrInvalidCertificateProperties("invalid network: %s", network)
  287. }
  288. if network.Addr().Is6() {
  289. return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
  290. }
  291. if network.Addr().IsUnspecified() {
  292. return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
  293. }
  294. if network.Addr().Zone() != "" {
  295. return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
  296. }
  297. }
  298. for _, network := range c.details.unsafeNetworks {
  299. if !network.IsValid() || !network.Addr().IsValid() {
  300. return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
  301. }
  302. if network.Addr().Is6() {
  303. return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
  304. }
  305. if network.Addr().Zone() != "" {
  306. return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
  307. }
  308. }
  309. // v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
  310. // We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
  311. // unsafe networks would result in a different signature.
  312. return nil
  313. }
  314. func (c *certificateV1) marshalForSigning() ([]byte, error) {
  315. b, err := proto.Marshal(c.getRawDetails())
  316. if err != nil {
  317. return nil, err
  318. }
  319. return b, nil
  320. }
  321. func (c *certificateV1) setSignature(b []byte) error {
  322. if len(b) == 0 {
  323. return ErrEmptySignature
  324. }
  325. c.signature = b
  326. return nil
  327. }
  328. // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
  329. // if the publicKey is provided here then it is not required to be present in `b`
  330. func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
  331. if len(b) == 0 {
  332. return nil, fmt.Errorf("nil byte array")
  333. }
  334. var rc RawNebulaCertificate
  335. err := proto.Unmarshal(b, &rc)
  336. if err != nil {
  337. return nil, err
  338. }
  339. if rc.Details == nil {
  340. return nil, fmt.Errorf("encoded Details was nil")
  341. }
  342. if len(rc.Details.Ips)%2 != 0 {
  343. return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
  344. }
  345. if len(rc.Details.Subnets)%2 != 0 {
  346. return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
  347. }
  348. nc := certificateV1{
  349. details: detailsV1{
  350. name: rc.Details.Name,
  351. groups: make([]string, len(rc.Details.Groups)),
  352. networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
  353. unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
  354. notBefore: time.Unix(rc.Details.NotBefore, 0),
  355. notAfter: time.Unix(rc.Details.NotAfter, 0),
  356. publicKey: make([]byte, len(rc.Details.PublicKey)),
  357. isCA: rc.Details.IsCA,
  358. curve: rc.Details.Curve,
  359. },
  360. signature: make([]byte, len(rc.Signature)),
  361. }
  362. copy(nc.signature, rc.Signature)
  363. copy(nc.details.groups, rc.Details.Groups)
  364. nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
  365. if len(publicKey) > 0 {
  366. nc.details.publicKey = publicKey
  367. }
  368. copy(nc.details.publicKey, rc.Details.PublicKey)
  369. var ip netip.Addr
  370. for i, rawIp := range rc.Details.Ips {
  371. if i%2 == 0 {
  372. ip = int2addr(rawIp)
  373. } else {
  374. ones, _ := net.IPMask(int2ip(rawIp)).Size()
  375. nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
  376. }
  377. }
  378. for i, rawIp := range rc.Details.Subnets {
  379. if i%2 == 0 {
  380. ip = int2addr(rawIp)
  381. } else {
  382. ones, _ := net.IPMask(int2ip(rawIp)).Size()
  383. nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
  384. }
  385. }
  386. err = nc.validate()
  387. if err != nil {
  388. return nil, err
  389. }
  390. return &nc, nil
  391. }
  392. func ip2int(ip []byte) uint32 {
  393. if len(ip) == 16 {
  394. return binary.BigEndian.Uint32(ip[12:16])
  395. }
  396. return binary.BigEndian.Uint32(ip)
  397. }
  398. func int2ip(nn uint32) net.IP {
  399. ip := make(net.IP, net.IPv4len)
  400. binary.BigEndian.PutUint32(ip, nn)
  401. return ip
  402. }
  403. func addr2int(addr netip.Addr) uint32 {
  404. b := addr.Unmap().As4()
  405. return binary.BigEndian.Uint32(b[:])
  406. }
  407. func int2addr(nn uint32) netip.Addr {
  408. ip := [4]byte{}
  409. binary.BigEndian.PutUint32(ip[:], nn)
  410. return netip.AddrFrom4(ip).Unmap()
  411. }