node_test.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package controller
  2. import (
  3. "net"
  4. "testing"
  5. "github.com/google/uuid"
  6. "github.com/gravitl/netmaker/database"
  7. "github.com/gravitl/netmaker/logic"
  8. "github.com/gravitl/netmaker/logic/acls"
  9. "github.com/gravitl/netmaker/logic/acls/nodeacls"
  10. "github.com/gravitl/netmaker/models"
  11. "github.com/gravitl/netmaker/servercfg"
  12. "github.com/stretchr/testify/assert"
  13. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  14. )
  15. var nonLinuxHost models.Host
  16. var linuxHost models.Host
  17. func TestGetNetworkNodes(t *testing.T) {
  18. deleteAllNetworks()
  19. createNet()
  20. t.Run("BadNet", func(t *testing.T) {
  21. node, err := logic.GetNetworkNodes("badnet")
  22. assert.Nil(t, err)
  23. assert.Equal(t, []models.Node{}, node)
  24. })
  25. t.Run("NoNodes", func(t *testing.T) {
  26. node, err := logic.GetNetworkNodes("skynet")
  27. assert.Nil(t, err)
  28. assert.Equal(t, []models.Node{}, node)
  29. })
  30. t.Run("Success", func(t *testing.T) {
  31. createTestNode()
  32. node, err := logic.GetNetworkNodes("skynet")
  33. assert.Nil(t, err)
  34. assert.NotEqual(t, []models.LegacyNode(nil), node)
  35. })
  36. }
  37. func TestValidateEgressGateway(t *testing.T) {
  38. var gateway models.EgressGatewayRequest
  39. t.Run("Success", func(t *testing.T) {
  40. gateway.Ranges = []string{"10.100.100.0/24"}
  41. err := logic.ValidateEgressGateway(gateway)
  42. assert.Nil(t, err)
  43. })
  44. }
  45. func TestNodeACLs(t *testing.T) {
  46. deleteAllNodes()
  47. node1 := createNodeWithParams("", "10.0.0.50/32")
  48. node2 := createNodeWithParams("", "10.0.0.100/32")
  49. logic.AssociateNodeToHost(node1, &linuxHost)
  50. logic.AssociateNodeToHost(node2, &linuxHost)
  51. t.Run("acls not present", func(t *testing.T) {
  52. currentACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(node1.Network))
  53. assert.Nil(t, err)
  54. assert.NotNil(t, currentACL)
  55. node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String()))
  56. assert.Nil(t, err)
  57. assert.NotNil(t, node1ACL)
  58. assert.Equal(t, acls.Allowed, node1ACL[acls.AclID(node2.ID.String())])
  59. })
  60. t.Run("node acls exists after creates", func(t *testing.T) {
  61. node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String()))
  62. assert.Nil(t, err)
  63. assert.NotNil(t, node1ACL)
  64. node2ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node2.Network), nodeacls.NodeID(node2.ID.String()))
  65. assert.Nil(t, err)
  66. assert.NotNil(t, node2ACL)
  67. assert.Equal(t, acls.Allowed, node2ACL[acls.AclID(node1.ID.String())])
  68. })
  69. t.Run("node acls correct after fetch", func(t *testing.T) {
  70. node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String()))
  71. assert.Nil(t, err)
  72. assert.Equal(t, acls.Allowed, node1ACL[acls.AclID(node2.ID.String())])
  73. })
  74. t.Run("node acls correct after modify", func(t *testing.T) {
  75. node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String()))
  76. assert.Nil(t, err)
  77. assert.NotNil(t, node1ACL)
  78. node2ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node2.Network), nodeacls.NodeID(node2.ID.String()))
  79. assert.Nil(t, err)
  80. assert.NotNil(t, node2ACL)
  81. currentACL, err := nodeacls.DisallowNodes(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String()), nodeacls.NodeID(node2.ID.String()))
  82. assert.Nil(t, err)
  83. assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node1.ID.String())][acls.AclID(node2.ID.String())])
  84. assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node2.ID.String())][acls.AclID(node1.ID.String())])
  85. currentACL.Save(acls.ContainerID(node1.Network))
  86. })
  87. t.Run("node acls correct after add new node not allowed", func(t *testing.T) {
  88. node3 := createNodeWithParams("", "10.0.0.100/32")
  89. createNodeHosts()
  90. n, e := logic.GetNetwork(node3.Network)
  91. assert.Nil(t, e)
  92. n.DefaultACL = "no"
  93. e = logic.SaveNetwork(&n)
  94. assert.Nil(t, e)
  95. err := logic.AssociateNodeToHost(node3, &linuxHost)
  96. assert.Nil(t, err)
  97. currentACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(node3.Network))
  98. assert.Nil(t, err)
  99. assert.NotNil(t, currentACL)
  100. assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node1.ID.String())][acls.AclID(node3.ID.String())])
  101. nodeACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node3.Network), nodeacls.NodeID(node3.ID.String()), acls.NotAllowed)
  102. assert.Nil(t, err)
  103. nodeACL.Save(acls.ContainerID(node3.Network), acls.AclID(node3.ID.String()))
  104. currentACL, err = nodeacls.FetchAllACLs(nodeacls.NetworkID(node3.Network))
  105. assert.Nil(t, err)
  106. assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node1.ID.String())][acls.AclID(node3.ID.String())])
  107. assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node2.ID.String())][acls.AclID(node3.ID.String())])
  108. })
  109. t.Run("node acls removed", func(t *testing.T) {
  110. retNetworkACL, err := nodeacls.RemoveNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String()))
  111. assert.Nil(t, err)
  112. assert.NotNil(t, retNetworkACL)
  113. assert.Equal(t, acls.NotPresent, retNetworkACL[acls.AclID(node2.ID.String())][acls.AclID(node1.ID.String())])
  114. })
  115. deleteAllNodes()
  116. }
  117. func deleteAllNodes() {
  118. if servercfg.CacheEnabled() {
  119. logic.ClearNodeCache()
  120. }
  121. database.DeleteAllRecords(database.NODES_TABLE_NAME)
  122. }
  123. func createTestNode() *models.Node {
  124. createNodeHosts()
  125. n := createNodeWithParams("skynet", "")
  126. _ = logic.AssociateNodeToHost(n, &linuxHost)
  127. return n
  128. }
  129. func createNodeWithParams(network, address string) *models.Node {
  130. _, ipnet, _ := net.ParseCIDR("10.0.0.1/32")
  131. tmpCNode := models.CommonNode{
  132. ID: uuid.New(),
  133. Network: "skynet",
  134. Address: *ipnet,
  135. }
  136. if len(network) > 0 {
  137. tmpCNode.Network = network
  138. }
  139. if len(address) > 0 {
  140. _, ipnet2, _ := net.ParseCIDR(address)
  141. tmpCNode.Address = *ipnet2
  142. }
  143. createnode := models.Node{
  144. CommonNode: tmpCNode,
  145. }
  146. return &createnode
  147. }
  148. func createNodeHosts() {
  149. k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
  150. linuxHost = models.Host{
  151. ID: uuid.New(),
  152. PublicKey: k.PublicKey(),
  153. HostPass: "password",
  154. OS: "linux",
  155. Name: "linuxhost",
  156. }
  157. _ = logic.CreateHost(&linuxHost)
  158. nonLinuxHost = models.Host{
  159. ID: uuid.New(),
  160. OS: "windows",
  161. PublicKey: k.PublicKey(),
  162. Name: "windowshost",
  163. HostPass: "password",
  164. }
  165. _ = logic.CreateHost(&nonLinuxHost)
  166. }