cert_v1.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. package cert
  2. import (
  3. "bytes"
  4. "crypto/ecdh"
  5. "crypto/ecdsa"
  6. "crypto/ed25519"
  7. "crypto/elliptic"
  8. "crypto/sha256"
  9. "encoding/binary"
  10. "encoding/hex"
  11. "encoding/json"
  12. "encoding/pem"
  13. "fmt"
  14. "net"
  15. "net/netip"
  16. "time"
  17. "golang.org/x/crypto/curve25519"
  18. "google.golang.org/protobuf/proto"
  19. )
// publicKeyLen is presumably the expected public key length in bytes (32, the size of
// an X25519/ed25519 public key). NOTE(review): it is not referenced in this chunk —
// confirm its users elsewhere in the package.
const publicKeyLen = 32
// certificateV1 is a version 1 nebula certificate: the certificate details plus a
// signature computed over the protobuf encoding of those details (see
// marshalForSigning and CheckSignature).
type certificateV1 struct {
	details   detailsV1
	signature []byte
}
// detailsV1 holds the signed contents of a v1 certificate. getRawDetails converts it
// into the protobuf struct that is marshalled, signed and verified.
type detailsV1 struct {
	name string
	// networks is encoded on the wire as (address, netmask) uint32 pairs in Ips.
	networks []netip.Prefix
	// unsafeNetworks is encoded the same way, in Subnets.
	unsafeNetworks []netip.Prefix
	groups         []string
	// notBefore and notAfter bound the validity window (both inclusive, see Expired).
	// They are serialized as Unix seconds, so sub-second precision is lost on marshal.
	notBefore time.Time
	notAfter  time.Time
	publicKey []byte
	isCA      bool
	// issuer is the hex encoding of the raw issuer bytes carried in the protobuf form.
	issuer string
	curve  Curve
}
  37. type m map[string]interface{}
  38. func (c *certificateV1) Version() Version {
  39. return Version1
  40. }
  41. func (c *certificateV1) Curve() Curve {
  42. return c.details.curve
  43. }
  44. func (c *certificateV1) Groups() []string {
  45. return c.details.groups
  46. }
  47. func (c *certificateV1) IsCA() bool {
  48. return c.details.isCA
  49. }
  50. func (c *certificateV1) Issuer() string {
  51. return c.details.issuer
  52. }
  53. func (c *certificateV1) Name() string {
  54. return c.details.name
  55. }
  56. func (c *certificateV1) Networks() []netip.Prefix {
  57. return c.details.networks
  58. }
  59. func (c *certificateV1) NotAfter() time.Time {
  60. return c.details.notAfter
  61. }
  62. func (c *certificateV1) NotBefore() time.Time {
  63. return c.details.notBefore
  64. }
  65. func (c *certificateV1) PublicKey() []byte {
  66. return c.details.publicKey
  67. }
  68. func (c *certificateV1) Signature() []byte {
  69. return c.signature
  70. }
  71. func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
  72. return c.details.unsafeNetworks
  73. }
  74. func (c *certificateV1) Fingerprint() (string, error) {
  75. b, err := c.Marshal()
  76. if err != nil {
  77. return "", err
  78. }
  79. sum := sha256.Sum256(b)
  80. return hex.EncodeToString(sum[:]), nil
  81. }
  82. func (c *certificateV1) CheckSignature(key []byte) bool {
  83. b, err := proto.Marshal(c.getRawDetails())
  84. if err != nil {
  85. return false
  86. }
  87. switch c.details.curve {
  88. case Curve_CURVE25519:
  89. return ed25519.Verify(key, b, c.signature)
  90. case Curve_P256:
  91. x, y := elliptic.Unmarshal(elliptic.P256(), key)
  92. pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
  93. hashed := sha256.Sum256(b)
  94. return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
  95. default:
  96. return false
  97. }
  98. }
  99. func (c *certificateV1) Expired(t time.Time) bool {
  100. return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
  101. }
  102. func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
  103. if curve != c.details.curve {
  104. return fmt.Errorf("curve in cert and private key supplied don't match")
  105. }
  106. if c.details.isCA {
  107. switch curve {
  108. case Curve_CURVE25519:
  109. // the call to PublicKey below will panic slice bounds out of range otherwise
  110. if len(key) != ed25519.PrivateKeySize {
  111. return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
  112. }
  113. if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
  114. return fmt.Errorf("public key in cert and private key supplied don't match")
  115. }
  116. case Curve_P256:
  117. privkey, err := ecdh.P256().NewPrivateKey(key)
  118. if err != nil {
  119. return fmt.Errorf("cannot parse private key as P256: %w", err)
  120. }
  121. pub := privkey.PublicKey().Bytes()
  122. if !bytes.Equal(pub, c.details.publicKey) {
  123. return fmt.Errorf("public key in cert and private key supplied don't match")
  124. }
  125. default:
  126. return fmt.Errorf("invalid curve: %s", curve)
  127. }
  128. return nil
  129. }
  130. var pub []byte
  131. switch curve {
  132. case Curve_CURVE25519:
  133. var err error
  134. pub, err = curve25519.X25519(key, curve25519.Basepoint)
  135. if err != nil {
  136. return err
  137. }
  138. case Curve_P256:
  139. privkey, err := ecdh.P256().NewPrivateKey(key)
  140. if err != nil {
  141. return err
  142. }
  143. pub = privkey.PublicKey().Bytes()
  144. default:
  145. return fmt.Errorf("invalid curve: %s", curve)
  146. }
  147. if !bytes.Equal(pub, c.details.publicKey) {
  148. return fmt.Errorf("public key in cert and private key supplied don't match")
  149. }
  150. return nil
  151. }
  152. // getRawDetails marshals the raw details into protobuf ready struct
  153. func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
  154. rd := &RawNebulaCertificateDetails{
  155. Name: c.details.name,
  156. Groups: c.details.groups,
  157. NotBefore: c.details.notBefore.Unix(),
  158. NotAfter: c.details.notAfter.Unix(),
  159. PublicKey: make([]byte, len(c.details.publicKey)),
  160. IsCA: c.details.isCA,
  161. Curve: c.details.curve,
  162. }
  163. for _, ipNet := range c.details.networks {
  164. mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
  165. rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
  166. }
  167. for _, ipNet := range c.details.unsafeNetworks {
  168. mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
  169. rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
  170. }
  171. copy(rd.PublicKey, c.details.publicKey[:])
  172. // I know, this is terrible
  173. rd.Issuer, _ = hex.DecodeString(c.details.issuer)
  174. return rd
  175. }
  176. func (c *certificateV1) String() string {
  177. b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
  178. if err != nil {
  179. return fmt.Sprintf("<error marshalling certificate: %v>", err)
  180. }
  181. return string(b)
  182. }
  183. func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
  184. pubKey := c.details.publicKey
  185. c.details.publicKey = nil
  186. rawCertNoKey, err := c.Marshal()
  187. if err != nil {
  188. return nil, err
  189. }
  190. c.details.publicKey = pubKey
  191. return rawCertNoKey, nil
  192. }
  193. func (c *certificateV1) Marshal() ([]byte, error) {
  194. rc := RawNebulaCertificate{
  195. Details: c.getRawDetails(),
  196. Signature: c.signature,
  197. }
  198. return proto.Marshal(&rc)
  199. }
  200. func (c *certificateV1) MarshalPEM() ([]byte, error) {
  201. b, err := c.Marshal()
  202. if err != nil {
  203. return nil, err
  204. }
  205. return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
  206. }
  207. func (c *certificateV1) MarshalJSON() ([]byte, error) {
  208. return json.Marshal(c.marshalJSON())
  209. }
  210. func (c *certificateV1) marshalJSON() m {
  211. fp, _ := c.Fingerprint()
  212. return m{
  213. "version": Version1,
  214. "details": m{
  215. "name": c.details.name,
  216. "networks": c.details.networks,
  217. "unsafeNetworks": c.details.unsafeNetworks,
  218. "groups": c.details.groups,
  219. "notBefore": c.details.notBefore,
  220. "notAfter": c.details.notAfter,
  221. "publicKey": fmt.Sprintf("%x", c.details.publicKey),
  222. "isCa": c.details.isCA,
  223. "issuer": c.details.issuer,
  224. "curve": c.details.curve.String(),
  225. },
  226. "fingerprint": fp,
  227. "signature": fmt.Sprintf("%x", c.Signature()),
  228. }
  229. }
  230. func (c *certificateV1) Copy() Certificate {
  231. nc := &certificateV1{
  232. details: detailsV1{
  233. name: c.details.name,
  234. notBefore: c.details.notBefore,
  235. notAfter: c.details.notAfter,
  236. publicKey: make([]byte, len(c.details.publicKey)),
  237. isCA: c.details.isCA,
  238. issuer: c.details.issuer,
  239. curve: c.details.curve,
  240. },
  241. signature: make([]byte, len(c.signature)),
  242. }
  243. if c.details.groups != nil {
  244. nc.details.groups = make([]string, len(c.details.groups))
  245. copy(nc.details.groups, c.details.groups)
  246. }
  247. if c.details.networks != nil {
  248. nc.details.networks = make([]netip.Prefix, len(c.details.networks))
  249. copy(nc.details.networks, c.details.networks)
  250. }
  251. if c.details.unsafeNetworks != nil {
  252. nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
  253. copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
  254. }
  255. copy(nc.signature, c.signature)
  256. copy(nc.details.publicKey, c.details.publicKey)
  257. return nc
  258. }
  259. func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
  260. c.details = detailsV1{
  261. name: t.Name,
  262. networks: t.Networks,
  263. unsafeNetworks: t.UnsafeNetworks,
  264. groups: t.Groups,
  265. notBefore: t.NotBefore,
  266. notAfter: t.NotAfter,
  267. publicKey: t.PublicKey,
  268. isCA: t.IsCA,
  269. curve: t.Curve,
  270. issuer: t.issuer,
  271. }
  272. return nil
  273. }
  274. func (c *certificateV1) marshalForSigning() ([]byte, error) {
  275. b, err := proto.Marshal(c.getRawDetails())
  276. if err != nil {
  277. return nil, err
  278. }
  279. return b, nil
  280. }
  281. func (c *certificateV1) setSignature(b []byte) error {
  282. c.signature = b
  283. return nil
  284. }
  285. // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
  286. // if the publicKey is provided here then it is not required to be present in `b`
  287. func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
  288. if len(b) == 0 {
  289. return nil, fmt.Errorf("nil byte array")
  290. }
  291. var rc RawNebulaCertificate
  292. err := proto.Unmarshal(b, &rc)
  293. if err != nil {
  294. return nil, err
  295. }
  296. if rc.Details == nil {
  297. return nil, fmt.Errorf("encoded Details was nil")
  298. }
  299. if len(rc.Details.Ips)%2 != 0 {
  300. return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
  301. }
  302. if len(rc.Details.Subnets)%2 != 0 {
  303. return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
  304. }
  305. nc := certificateV1{
  306. details: detailsV1{
  307. name: rc.Details.Name,
  308. groups: make([]string, len(rc.Details.Groups)),
  309. networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
  310. unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
  311. notBefore: time.Unix(rc.Details.NotBefore, 0),
  312. notAfter: time.Unix(rc.Details.NotAfter, 0),
  313. publicKey: make([]byte, len(rc.Details.PublicKey)),
  314. isCA: rc.Details.IsCA,
  315. curve: rc.Details.Curve,
  316. },
  317. signature: make([]byte, len(rc.Signature)),
  318. }
  319. copy(nc.signature, rc.Signature)
  320. copy(nc.details.groups, rc.Details.Groups)
  321. nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
  322. if len(publicKey) > 0 {
  323. nc.details.publicKey = publicKey
  324. }
  325. copy(nc.details.publicKey, rc.Details.PublicKey)
  326. var ip netip.Addr
  327. for i, rawIp := range rc.Details.Ips {
  328. if i%2 == 0 {
  329. ip = int2addr(rawIp)
  330. } else {
  331. ones, _ := net.IPMask(int2ip(rawIp)).Size()
  332. nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
  333. }
  334. }
  335. for i, rawIp := range rc.Details.Subnets {
  336. if i%2 == 0 {
  337. ip = int2addr(rawIp)
  338. } else {
  339. ones, _ := net.IPMask(int2ip(rawIp)).Size()
  340. nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
  341. }
  342. }
  343. return &nc, nil
  344. }
  345. func ip2int(ip []byte) uint32 {
  346. if len(ip) == 16 {
  347. return binary.BigEndian.Uint32(ip[12:16])
  348. }
  349. return binary.BigEndian.Uint32(ip)
  350. }
  351. func int2ip(nn uint32) net.IP {
  352. ip := make(net.IP, net.IPv4len)
  353. binary.BigEndian.PutUint32(ip, nn)
  354. return ip
  355. }
  356. func addr2int(addr netip.Addr) uint32 {
  357. b := addr.Unmap().As4()
  358. return binary.BigEndian.Uint32(b[:])
  359. }
  360. func int2addr(nn uint32) netip.Addr {
  361. ip := [4]byte{}
  362. binary.BigEndian.PutUint32(ip[:], nn)
  363. return netip.AddrFrom4(ip).Unmap()
  364. }