Browse Source

add emqx acls

Anish Mukherjee 2 years ago
parent
commit
8a9f569c4f
4 changed files with 239 additions and 0 deletions
  1. 17 0
      controllers/enrollmentkeys.go
  2. 4 0
      controllers/node.go
  3. 210 0
      mq/emqx.go
  4. 8 0
      mq/mq.go

+ 17 - 0
controllers/enrollmentkeys.go

@@ -182,9 +182,21 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest"))
 		return
 	}
+	hostPass := newHost.HostPass
 	if !hostExists {
 		// register host
 		logic.CheckHostPorts(&newHost)
+		// create EMQX credentials and ACLs for host
+		if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
+			if err := mq.CreateEmqxUser(newHost.ID.String(), newHost.HostPass, false); err != nil {
+				logger.Log(0, "failed to create host credentials for EMQX: ", err.Error())
+				return
+			}
+			if err := mq.CreateHostACL(newHost.ID.String(), servercfg.GetServerInfo().Server); err != nil {
+				logger.Log(0, "failed to add host ACL rules to EMQX: ", err.Error())
+				return
+			}
+		}
 		if err = logic.CreateHost(&newHost); err != nil {
 			logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error())
 			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -205,6 +217,11 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
 	// ready the response
 	server := servercfg.GetServerInfo()
 	server.TrafficKey = key
+	if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
+		// set MQ username and password for EMQX clients
+		server.MQUserName = newHost.ID.String()
+		server.MQPassword = hostPass
+	}
 	response := models.RegisterResponse{
 		ServerConf:    server,
 		RequestedHost: newHost,

+ 4 - 0
controllers/node.go

@@ -394,6 +394,10 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	server := servercfg.GetServerInfo()
+	if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
+		// set MQ username for EMQX clients
+		server.MQUserName = host.ID.String()
+	}
 	response := models.NodeGet{
 		Node:         node,
 		Host:         *host,

+ 210 - 0
mq/emqx.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"sync"
 
 	"github.com/gravitl/netmaker/servercfg"
 )
@@ -29,6 +30,17 @@ type (
 		Token   string `json:"token"`
 		Version string `json:"version"`
 	}
+
+	aclRule struct {
+		Topic      string `json:"topic"`
+		Permission string `json:"permission"`
+		Action     string `json:"action"`
+	}
+
+	aclObject struct {
+		Rules    []aclRule `json:"rules"`
+		Username string    `json:"username,omitempty"`
+	}
 )
 
 func getEmqxAuthToken() (string, error) {
@@ -152,3 +164,201 @@ func CreateEmqxDefaultAuthenticator() error {
 	}
 	return nil
 }
+
+// CreateEmqxDefaultAuthorizer - creates a default ACL authorization mechanism based on the built in database
+func CreateEmqxDefaultAuthorizer() error {
+	token, err := getEmqxAuthToken()
+	if err != nil {
+		return err
+	}
+	payload, err := json.Marshal(&struct {
+		Enable bool   `json:"enable"`
+		Type   string `json:"type"`
+	}{Enable: true, Type: "built_in_database"})
+	if err != nil {
+		return err
+	}
+	req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources", bytes.NewReader(payload))
+	if err != nil {
+		return err
+	}
+	req.Header.Add("content-type", "application/json")
+	req.Header.Add("authorization", "Bearer "+token)
+	resp, err := (&http.Client{}).Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusNoContent {
+		msg, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+		return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg))
+	}
+	return nil
+}
+
+// GetUserACL - returns ACL rules by username
+func GetUserACL(username string) (*aclObject, error) {
+	token, err := getEmqxAuthToken()
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequest(http.MethodGet, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/rules/users/"+username, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Add("content-type", "application/json")
+	req.Header.Add("authorization", "Bearer "+token)
+	resp, err := (&http.Client{}).Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	response, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("error fetching ACL rules %v", string(response))
+	}
+	body := new(aclObject)
+	if err := json.Unmarshal(response, body); err != nil {
+		return nil, err
+	}
+	return body, nil
+}
+
+// CreateDefaultDenyRule - creates a rule to deny access to all topics for all users by default
+// to allow user access to topics use the `mq.CreateUserAccessRule` function
+func CreateDefaultDenyRule() error {
+	token, err := getEmqxAuthToken()
+	if err != nil {
+		return err
+	}
+	payload, err := json.Marshal(&aclObject{Rules: []aclRule{{Topic: "#", Permission: "deny", Action: "all"}}})
+	if err != nil {
+		return err
+	}
+	req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/rules/all", bytes.NewReader(payload))
+	if err != nil {
+		return err
+	}
+	req.Header.Add("content-type", "application/json")
+	req.Header.Add("authorization", "Bearer "+token)
+	resp, err := (&http.Client{}).Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusNoContent {
+		msg, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+		return fmt.Errorf("error creating default ACL rules %v", string(msg))
+	}
+	return nil
+}
+
+// CreateHostACL - create host ACL rules
+func CreateHostACL(hostID, serverName string) error {
+	token, err := getEmqxAuthToken()
+	if err != nil {
+		return err
+	}
+	payload, err := json.Marshal(&aclObject{
+		Username: hostID,
+		Rules: []aclRule{
+			{
+				Topic:      fmt.Sprintf("peers/host/%s/%s", hostID, serverName),
+				Permission: "allow",
+				Action:     "all",
+			},
+			{
+				Topic:      fmt.Sprintf("host/update/%s/%s", hostID, serverName),
+				Permission: "allow",
+				Action:     "all",
+			},
+			{
+				Topic:      fmt.Sprintf("dns/all/%s/%s", hostID, serverName),
+				Permission: "allow",
+				Action:     "all",
+			},
+			{
+				Topic:      fmt.Sprintf("dns/update/%s/%s", hostID, serverName),
+				Permission: "allow",
+				Action:     "all",
+			},
+		},
+	})
+	if err != nil {
+		return err
+	}
+	req, err := http.NewRequest(http.MethodPut, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/rules/users/"+hostID, bytes.NewReader(payload))
+	if err != nil {
+		return err
+	}
+	req.Header.Add("content-type", "application/json")
+	req.Header.Add("authorization", "Bearer "+token)
+	resp, err := (&http.Client{}).Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusNoContent {
+		msg, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+		return fmt.Errorf("error adding ACL Rules for user %s Error: %v", hostID, string(msg))
+	}
+	return nil
+}
+
+// a lock required for preventing simultaneous updates to the same ACL object leading to overwriting each other
+// might occur when multiple nodes belonging to the same host are created at the same time
+var nodeAclMux sync.Mutex
+
+// AppendNodeUpdateACL - adds ACL rule for subscribing to node updates for a node ID
+func AppendNodeUpdateACL(hostID, nodeNetwork, nodeID string) error {
+	nodeAclMux.Lock()
+	defer nodeAclMux.Unlock()
+	token, err := getEmqxAuthToken()
+	if err != nil {
+		return err
+	}
+	aclObject, err := GetUserACL(hostID)
+	if err != nil {
+		return err
+	}
+	aclObject.Rules = append(aclObject.Rules, aclRule{
+		Topic:      fmt.Sprintf("node/update/%s/%s", nodeNetwork, nodeID),
+		Permission: "allow",
+		Action:     "subscribe",
+	})
+	payload, err := json.Marshal(aclObject)
+	if err != nil {
+		return err
+	}
+	req, err := http.NewRequest(http.MethodPut, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/rules/users/"+hostID, bytes.NewReader(payload))
+	if err != nil {
+		return err
+	}
+	req.Header.Add("content-type", "application/json")
+	req.Header.Add("authorization", "Bearer "+token)
+	resp, err := (&http.Client{}).Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusNoContent {
+		msg, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+		return fmt.Errorf("error adding ACL Rules for user %s Error: %v", hostID, string(msg))
+	}
+	return nil
+}

+ 8 - 0
mq/mq.go

@@ -50,6 +50,14 @@ func SetupMQTT() {
 		if err := CreateEmqxUser(servercfg.GetMqUserName(), servercfg.GetMqPassword(), true); err != nil {
 			log.Fatal(err)
 		}
+		// create an ACL authorization source for the built in EMQX MNESIA database
+		if err := CreateEmqxDefaultAuthorizer(); err != nil {
+			logger.Log(0, err.Error())
+		}
+		// create a default deny ACL to all topics for all users
+		if err := CreateDefaultDenyRule(); err != nil {
+			log.Fatal(err)
+		}
 	}
 	opts := mqtt.NewClientOptions()
 	setMqOptions(servercfg.GetMqUserName(), servercfg.GetMqPassword(), opts)