handshake.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package nebula
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha256"
  5. "errors"
  6. "github.com/golang/protobuf/proto"
  7. )
  // Handshake subtypes carried in Header.Subtype for handshake packets.
  // The values are wire-visible, so the declaration order must not change.
  const (
  	// handshakeIXPSK0 selects the IX handshake pattern with PSK0 (value 0).
  	handshakeIXPSK0 = iota
  	// handshakeXXPSK0 selects the XX handshake pattern with PSK0 (value 1).
  	handshakeXXPSK0
  )
  12. func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
  13. newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
  14. //TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases
  15. //if err != nil {
  16. // l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message")
  17. // return
  18. //}
  19. tearDown := false
  20. switch h.Subtype {
  21. case handshakeIXPSK0:
  22. switch h.MessageCounter {
  23. case 1:
  24. tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h)
  25. case 2:
  26. tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h)
  27. }
  28. }
  29. if tearDown && newHostinfo != nil {
  30. f.handshakeManager.DeleteIndex(newHostinfo.localIndexId)
  31. f.handshakeManager.DeleteVpnIP(newHostinfo.hostId)
  32. }
  33. }
  34. func HandshakeBytesWithMAC(details *NebulaHandshakeDetails, key []byte) ([]byte, error) {
  35. mac := hmac.New(sha256.New, key)
  36. b, err := proto.Marshal(details)
  37. if err != nil {
  38. return nil, errors.New("Unable to marshal nebula handshake")
  39. }
  40. mac.Write(b)
  41. sum := mac.Sum(nil)
  42. hs := &NebulaHandshake{
  43. Details: details,
  44. Hmac: sum,
  45. }
  46. hsBytes, err := proto.Marshal(hs)
  47. if err != nil {
  48. l.Debugln("failed to generate NebulaHandshake protobuf", err)
  49. }
  50. return hsBytes, nil
  51. }
  52. func (hs *NebulaHandshake) CheckHandshakeMAC(keys [][]byte) bool {
  53. b, err := proto.Marshal(hs.Details)
  54. if err != nil {
  55. return false
  56. }
  57. for _, k := range keys {
  58. mac := hmac.New(sha256.New, k)
  59. mac.Write(b)
  60. expectedMAC := mac.Sum(nil)
  61. if hmac.Equal(hs.Hmac, expectedMAC) {
  62. return true
  63. }
  64. }
  65. //l.Debugln(hs.Hmac, expectedMAC)
  66. return false
  67. }