2
0

network_test.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package controller
  2. import (
  3. "context"
  4. "github.com/gravitl/netmaker/db"
  5. "github.com/gravitl/netmaker/schema"
  6. "os"
  7. "testing"
  8. "github.com/google/uuid"
  9. "github.com/gravitl/netmaker/database"
  10. "github.com/gravitl/netmaker/logger"
  11. "github.com/gravitl/netmaker/logic"
  12. "github.com/gravitl/netmaker/models"
  13. "github.com/stretchr/testify/assert"
  14. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  15. )
  16. type NetworkValidationTestCase struct {
  17. testname string
  18. network models.Network
  19. errMessage string
  20. }
  21. var netHost models.Host
  22. func TestMain(m *testing.M) {
  23. db.InitializeDB(schema.ListModels()...)
  24. defer db.CloseDB()
  25. database.InitializeDatabase()
  26. defer database.CloseDB()
  27. logic.CreateSuperAdmin(&models.User{
  28. UserName: "admin",
  29. Password: "password",
  30. PlatformRoleID: models.SuperAdminRole,
  31. })
  32. peerUpdate := make(chan *models.Node)
  33. go logic.ManageZombies(context.Background(), peerUpdate)
  34. go func() {
  35. for update := range peerUpdate {
  36. //do nothing
  37. logger.Log(3, "received node update", update.Action)
  38. }
  39. }()
  40. os.Exit(m.Run())
  41. }
  42. func TestCreateNetwork(t *testing.T) {
  43. deleteAllNetworks()
  44. var network models.Network
  45. network.NetID = "skynet1"
  46. network.AddressRange = "10.10.0.1/24"
  47. // if tests break - check here (removed displayname)
  48. //network.DisplayName = "mynetwork"
  49. _, err := logic.CreateNetwork(network)
  50. assert.Nil(t, err)
  51. }
  52. func TestGetNetwork(t *testing.T) {
  53. createNet()
  54. t.Run("GetExistingNetwork", func(t *testing.T) {
  55. network, err := logic.GetNetwork("skynet")
  56. assert.Nil(t, err)
  57. assert.Equal(t, "skynet", network.NetID)
  58. })
  59. t.Run("GetNonExistantNetwork", func(t *testing.T) {
  60. network, err := logic.GetNetwork("doesnotexist")
  61. assert.EqualError(t, err, "no result found")
  62. assert.Equal(t, "", network.NetID)
  63. })
  64. }
  65. func TestDeleteNetwork(t *testing.T) {
  66. createNet()
  67. //create nodes
  68. t.Run("NetworkwithNodes", func(t *testing.T) {
  69. })
  70. t.Run("DeleteExistingNetwork", func(t *testing.T) {
  71. doneCh := make(chan struct{}, 1)
  72. err := logic.DeleteNetwork("skynet", false, doneCh)
  73. assert.Nil(t, err)
  74. })
  75. t.Run("NonExistentNetwork", func(t *testing.T) {
  76. doneCh := make(chan struct{}, 1)
  77. err := logic.DeleteNetwork("skynet", false, doneCh)
  78. assert.Nil(t, err)
  79. })
  80. createNetv1("test")
  81. t.Run("ForceDeleteNetwork", func(t *testing.T) {
  82. doneCh := make(chan struct{}, 1)
  83. err := logic.DeleteNetwork("test", true, doneCh)
  84. assert.Nil(t, err)
  85. })
  86. }
  87. func TestSecurityCheck(t *testing.T) {
  88. //these seem to work but not sure it the tests are really testing the functionality
  89. os.Setenv("MASTER_KEY", "secretkey")
  90. t.Run("NoNetwork", func(t *testing.T) {
  91. username, err := logic.UserPermissions(false, "Bearer secretkey")
  92. assert.Nil(t, err)
  93. t.Log(username)
  94. })
  95. t.Run("BadToken", func(t *testing.T) {
  96. username, err := logic.UserPermissions(false, "Bearer badkey")
  97. assert.NotNil(t, err)
  98. t.Log(err)
  99. t.Log(username)
  100. })
  101. }
  102. func TestValidateNetwork(t *testing.T) {
  103. //t.Skip()
  104. //This functions is not called by anyone
  105. //it panics as validation function 'display_name_valid' is not defined
  106. //yes := true
  107. //no := false
  108. //deleteNet(t)
  109. //DeleteNetworks
  110. cases := []NetworkValidationTestCase{
  111. {
  112. testname: "InvalidAddress",
  113. network: models.Network{
  114. NetID: "skynet",
  115. AddressRange: "10.0.0.256",
  116. },
  117. errMessage: "Field validation for 'AddressRange' failed on the 'cidrv4' tag",
  118. },
  119. {
  120. testname: "InvalidAddress6",
  121. network: models.Network{
  122. NetID: "skynet1",
  123. AddressRange6: "2607::ffff/130",
  124. },
  125. errMessage: "Field validation for 'AddressRange6' failed on the 'cidrv6' tag",
  126. },
  127. {
  128. testname: "InvalidNetID",
  129. network: models.Network{
  130. NetID: "with spaces",
  131. },
  132. errMessage: "Field validation for 'NetID' failed on the 'netid_valid' tag",
  133. },
  134. {
  135. testname: "NetIDTooLong",
  136. network: models.Network{
  137. NetID: "LongNetIDNameForMaxCharactersTest",
  138. },
  139. errMessage: "Field validation for 'NetID' failed on the 'max' tag",
  140. },
  141. {
  142. testname: "ListenPortTooLow",
  143. network: models.Network{
  144. NetID: "skynet",
  145. DefaultListenPort: 1023,
  146. },
  147. errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag",
  148. },
  149. {
  150. testname: "ListenPortTooHigh",
  151. network: models.Network{
  152. NetID: "skynet",
  153. DefaultListenPort: 65536,
  154. },
  155. errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag",
  156. },
  157. {
  158. testname: "KeepAliveTooBig",
  159. network: models.Network{
  160. NetID: "skynet",
  161. DefaultKeepalive: 1010,
  162. },
  163. errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag",
  164. },
  165. }
  166. for _, tc := range cases {
  167. t.Run(tc.testname, func(t *testing.T) {
  168. t.Log(tc.testname)
  169. network := models.Network(tc.network)
  170. network.SetDefaults()
  171. err := logic.ValidateNetwork(&network, false)
  172. assert.NotNil(t, err)
  173. assert.Contains(t, err.Error(), tc.errMessage) // test passes if err.Error() contains the expected errMessage.
  174. })
  175. }
  176. }
  177. func TestIpv6Network(t *testing.T) {
  178. //these seem to work but not sure it the tests are really testing the functionality
  179. os.Setenv("MASTER_KEY", "secretkey")
  180. deleteAllNetworks()
  181. createNet()
  182. createNetDualStack()
  183. network, err := logic.GetNetwork("skynet6")
  184. t.Run("Test Network Create IPv6", func(t *testing.T) {
  185. assert.Nil(t, err)
  186. assert.Equal(t, network.AddressRange6, "fde6:be04:fa5e:d076::/64")
  187. })
  188. node1 := createNodeWithParams("skynet6", "")
  189. createNetHost()
  190. nodeErr := logic.AssociateNodeToHost(node1, &netHost)
  191. t.Run("Test node on network IPv6", func(t *testing.T) {
  192. assert.Nil(t, nodeErr)
  193. assert.Equal(t, "fde6:be04:fa5e:d076::1", node1.Address6.IP.String())
  194. })
  195. }
  196. func deleteAllNetworks() {
  197. deleteAllNodes()
  198. database.DeleteAllRecords(database.NETWORKS_TABLE_NAME)
  199. }
  200. func createNet() {
  201. var network models.Network
  202. network.NetID = "skynet"
  203. network.AddressRange = "10.0.0.1/24"
  204. _, err := logic.GetNetwork("skynet")
  205. if err != nil {
  206. logic.CreateNetwork(network)
  207. }
  208. }
  209. func createNetv1(netId string) {
  210. var network models.Network
  211. network.NetID = netId
  212. network.AddressRange = "100.0.0.1/24"
  213. _, err := logic.GetNetwork(netId)
  214. if err != nil {
  215. logic.CreateNetwork(network)
  216. }
  217. }
  218. func createNetDualStack() {
  219. var network models.Network
  220. network.NetID = "skynet6"
  221. network.AddressRange = "10.1.2.0/24"
  222. network.AddressRange6 = "fde6:be04:fa5e:d076::/64"
  223. network.IsIPv4 = "yes"
  224. network.IsIPv6 = "yes"
  225. _, err := logic.GetNetwork("skynet6")
  226. if err != nil {
  227. logic.CreateNetwork(network)
  228. }
  229. }
  230. func createNetHost() {
  231. k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
  232. netHost = models.Host{
  233. ID: uuid.New(),
  234. PublicKey: k.PublicKey(),
  235. HostPass: "password",
  236. OS: "linux",
  237. Name: "nethost",
  238. }
  239. _ = logic.CreateHost(&netHost)
  240. }