network_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. package controller
  2. import (
  3. "context"
  4. "os"
  5. "testing"
  6. "github.com/google/uuid"
  7. "github.com/gravitl/netmaker/database"
  8. "github.com/gravitl/netmaker/logger"
  9. "github.com/gravitl/netmaker/logic"
  10. "github.com/gravitl/netmaker/models"
  11. "github.com/stretchr/testify/assert"
  12. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  13. )
  14. type NetworkValidationTestCase struct {
  15. testname string
  16. network models.Network
  17. errMessage string
  18. }
  19. var netHost models.Host
  20. func TestMain(m *testing.M) {
  21. database.InitializeDatabase()
  22. defer database.CloseDB()
  23. logic.CreateAdmin(&models.User{
  24. UserName: "admin",
  25. Password: "password",
  26. IsAdmin: true,
  27. Networks: []string{},
  28. Groups: []string{},
  29. })
  30. peerUpdate := make(chan *models.Node)
  31. go logic.ManageZombies(context.Background(), peerUpdate)
  32. go func() {
  33. for update := range peerUpdate {
  34. //do nothing
  35. logger.Log(3, "received node update", update.Action)
  36. }
  37. }()
  38. os.Exit(m.Run())
  39. }
  40. func TestCreateNetwork(t *testing.T) {
  41. deleteAllNetworks()
  42. var network models.Network
  43. network.NetID = "skynet"
  44. network.AddressRange = "10.0.0.1/24"
  45. // if tests break - check here (removed displayname)
  46. //network.DisplayName = "mynetwork"
  47. _, err := logic.CreateNetwork(network)
  48. assert.Nil(t, err)
  49. }
  50. func TestGetNetwork(t *testing.T) {
  51. createNet()
  52. t.Run("GetExistingNetwork", func(t *testing.T) {
  53. network, err := logic.GetNetwork("skynet")
  54. assert.Nil(t, err)
  55. assert.Equal(t, "skynet", network.NetID)
  56. })
  57. t.Run("GetNonExistantNetwork", func(t *testing.T) {
  58. network, err := logic.GetNetwork("doesnotexist")
  59. assert.EqualError(t, err, "no result found")
  60. assert.Equal(t, "", network.NetID)
  61. })
  62. }
  63. func TestDeleteNetwork(t *testing.T) {
  64. createNet()
  65. //create nodes
  66. t.Run("NetworkwithNodes", func(t *testing.T) {
  67. })
  68. t.Run("DeleteExistingNetwork", func(t *testing.T) {
  69. err := logic.DeleteNetwork("skynet")
  70. assert.Nil(t, err)
  71. })
  72. t.Run("NonExistantNetwork", func(t *testing.T) {
  73. err := logic.DeleteNetwork("skynet")
  74. assert.Nil(t, err)
  75. })
  76. }
  77. func TestCreateKey(t *testing.T) {
  78. createNet()
  79. keys, _ := logic.GetKeys("skynet")
  80. for _, key := range keys {
  81. logic.DeleteKey(key.Name, "skynet")
  82. }
  83. var accesskey models.AccessKey
  84. var network models.Network
  85. network.NetID = "skynet"
  86. t.Run("NameTooLong", func(t *testing.T) {
  87. network, err := logic.GetNetwork("skynet")
  88. assert.Nil(t, err)
  89. accesskey.Name = "ThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfail"
  90. _, err = logic.CreateAccessKey(accesskey, network)
  91. assert.NotNil(t, err)
  92. assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
  93. })
  94. t.Run("BlankName", func(t *testing.T) {
  95. network, err := logic.GetNetwork("skynet")
  96. assert.Nil(t, err)
  97. accesskey.Name = ""
  98. key, err := logic.CreateAccessKey(accesskey, network)
  99. assert.Nil(t, err)
  100. assert.NotEqual(t, "", key.Name)
  101. })
  102. t.Run("InvalidValue", func(t *testing.T) {
  103. network, err := logic.GetNetwork("skynet")
  104. assert.Nil(t, err)
  105. accesskey.Value = "bad-value"
  106. _, err = logic.CreateAccessKey(accesskey, network)
  107. assert.NotNil(t, err)
  108. assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'alphanum' tag")
  109. })
  110. t.Run("BlankValue", func(t *testing.T) {
  111. network, err := logic.GetNetwork("skynet")
  112. assert.Nil(t, err)
  113. accesskey.Name = "mykey"
  114. accesskey.Value = ""
  115. key, err := logic.CreateAccessKey(accesskey, network)
  116. assert.Nil(t, err)
  117. assert.NotEqual(t, "", key.Value)
  118. assert.Equal(t, accesskey.Name, key.Name)
  119. })
  120. t.Run("ValueTooLong", func(t *testing.T) {
  121. network, err := logic.GetNetwork("skynet")
  122. assert.Nil(t, err)
  123. accesskey.Name = "keyname"
  124. accesskey.Value = "AccessKeyValuethatistoolong"
  125. _, err = logic.CreateAccessKey(accesskey, network)
  126. assert.NotNil(t, err)
  127. assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'max' tag")
  128. })
  129. t.Run("BlankUses", func(t *testing.T) {
  130. network, err := logic.GetNetwork("skynet")
  131. assert.Nil(t, err)
  132. accesskey.Uses = 0
  133. accesskey.Value = ""
  134. key, err := logic.CreateAccessKey(accesskey, network)
  135. assert.Nil(t, err)
  136. assert.Equal(t, 1, key.Uses)
  137. })
  138. t.Run("DuplicateKey", func(t *testing.T) {
  139. network, err := logic.GetNetwork("skynet")
  140. assert.Nil(t, err)
  141. accesskey.Name = "mykey"
  142. _, err = logic.CreateAccessKey(accesskey, network)
  143. assert.NotNil(t, err)
  144. assert.EqualError(t, err, "duplicate AccessKey Name")
  145. })
  146. }
  147. func TestGetKeys(t *testing.T) {
  148. deleteAllNetworks()
  149. createNet()
  150. network, err := logic.GetNetwork("skynet")
  151. assert.Nil(t, err)
  152. var key models.AccessKey
  153. key.Name = "mykey"
  154. _, err = logic.CreateAccessKey(key, network)
  155. assert.Nil(t, err)
  156. t.Run("KeyExists", func(t *testing.T) {
  157. keys, err := logic.GetKeys(network.NetID)
  158. assert.Nil(t, err)
  159. assert.NotEqual(t, models.AccessKey{}, keys)
  160. })
  161. t.Run("NonExistantKey", func(t *testing.T) {
  162. err := logic.DeleteKey("mykey", "skynet")
  163. assert.Nil(t, err)
  164. keys, err := logic.GetKeys(network.NetID)
  165. assert.Nil(t, err)
  166. assert.Equal(t, []models.AccessKey(nil), keys)
  167. })
  168. }
  169. func TestDeleteKey(t *testing.T) {
  170. createNet()
  171. network, err := logic.GetNetwork("skynet")
  172. assert.Nil(t, err)
  173. var key models.AccessKey
  174. key.Name = "mykey"
  175. _, err = logic.CreateAccessKey(key, network)
  176. assert.Nil(t, err)
  177. t.Run("ExistingKey", func(t *testing.T) {
  178. err := logic.DeleteKey("mykey", "skynet")
  179. assert.Nil(t, err)
  180. })
  181. t.Run("NonExistantKey", func(t *testing.T) {
  182. err := logic.DeleteKey("mykey", "skynet")
  183. assert.NotNil(t, err)
  184. assert.Equal(t, "key mykey does not exist", err.Error())
  185. })
  186. }
  187. func TestSecurityCheck(t *testing.T) {
  188. //these seem to work but not sure it the tests are really testing the functionality
  189. os.Setenv("MASTER_KEY", "secretkey")
  190. t.Run("NoNetwork", func(t *testing.T) {
  191. networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey")
  192. assert.Nil(t, err)
  193. t.Log(networks, username)
  194. })
  195. t.Run("WithNetwork", func(t *testing.T) {
  196. networks, username, err := logic.UserPermissions(false, "skynet", "Bearer secretkey")
  197. assert.Nil(t, err)
  198. t.Log(networks, username)
  199. })
  200. t.Run("BadNet", func(t *testing.T) {
  201. t.Skip()
  202. networks, username, err := logic.UserPermissions(false, "badnet", "Bearer secretkey")
  203. assert.NotNil(t, err)
  204. t.Log(err)
  205. t.Log(networks, username)
  206. })
  207. t.Run("BadToken", func(t *testing.T) {
  208. networks, username, err := logic.UserPermissions(false, "skynet", "Bearer badkey")
  209. assert.NotNil(t, err)
  210. t.Log(err)
  211. t.Log(networks, username)
  212. })
  213. }
  214. func TestValidateNetwork(t *testing.T) {
  215. //t.Skip()
  216. //This functions is not called by anyone
  217. //it panics as validation function 'display_name_valid' is not defined
  218. //yes := true
  219. //no := false
  220. //deleteNet(t)
  221. //DeleteNetworks
  222. cases := []NetworkValidationTestCase{
  223. {
  224. testname: "InvalidAddress",
  225. network: models.Network{
  226. NetID: "skynet",
  227. AddressRange: "10.0.0.256",
  228. },
  229. errMessage: "Field validation for 'AddressRange' failed on the 'cidrv4' tag",
  230. },
  231. {
  232. testname: "InvalidAddress6",
  233. network: models.Network{
  234. NetID: "skynet1",
  235. AddressRange6: "2607::ffff/130",
  236. },
  237. errMessage: "Field validation for 'AddressRange6' failed on the 'cidrv6' tag",
  238. },
  239. {
  240. testname: "InvalidNetID",
  241. network: models.Network{
  242. NetID: "with spaces",
  243. },
  244. errMessage: "Field validation for 'NetID' failed on the 'netid_valid' tag",
  245. },
  246. {
  247. testname: "NetIDTooLong",
  248. network: models.Network{
  249. NetID: "LongNetIDName",
  250. },
  251. errMessage: "Field validation for 'NetID' failed on the 'max' tag",
  252. },
  253. {
  254. testname: "ListenPortTooLow",
  255. network: models.Network{
  256. NetID: "skynet",
  257. DefaultListenPort: 1023,
  258. },
  259. errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag",
  260. },
  261. {
  262. testname: "ListenPortTooHigh",
  263. network: models.Network{
  264. NetID: "skynet",
  265. DefaultListenPort: 65536,
  266. },
  267. errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag",
  268. },
  269. {
  270. testname: "KeepAliveTooBig",
  271. network: models.Network{
  272. NetID: "skynet",
  273. DefaultKeepalive: 1010,
  274. },
  275. errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag",
  276. },
  277. }
  278. for _, tc := range cases {
  279. t.Run(tc.testname, func(t *testing.T) {
  280. t.Log(tc.testname)
  281. network := models.Network(tc.network)
  282. network.SetDefaults()
  283. err := logic.ValidateNetwork(&network, false)
  284. assert.NotNil(t, err)
  285. assert.Contains(t, err.Error(), tc.errMessage) // test passes if err.Error() contains the expected errMessage.
  286. })
  287. }
  288. }
  289. func TestIpv6Network(t *testing.T) {
  290. //these seem to work but not sure it the tests are really testing the functionality
  291. os.Setenv("MASTER_KEY", "secretkey")
  292. deleteAllNetworks()
  293. createNet()
  294. createNetDualStack()
  295. network, err := logic.GetNetwork("skynet6")
  296. t.Run("Test Network Create IPv6", func(t *testing.T) {
  297. assert.Nil(t, err)
  298. assert.Equal(t, network.AddressRange6, "fde6:be04:fa5e:d076::/64")
  299. })
  300. node1 := createNodeWithParams("skynet6", "")
  301. createNetHost()
  302. nodeErr := logic.AssociateNodeToHost(node1, &netHost)
  303. t.Run("Test node on network IPv6", func(t *testing.T) {
  304. assert.Nil(t, nodeErr)
  305. assert.Equal(t, "fde6:be04:fa5e:d076::1", node1.Address6.IP.String())
  306. })
  307. }
  308. func deleteAllNetworks() {
  309. deleteAllNodes()
  310. nets, _ := logic.GetNetworks()
  311. for _, net := range nets {
  312. logic.DeleteNetwork(net.NetID)
  313. }
  314. }
  315. func createNet() {
  316. var network models.Network
  317. network.NetID = "skynet"
  318. network.AddressRange = "10.0.0.1/24"
  319. _, err := logic.GetNetwork("skynet")
  320. if err != nil {
  321. logic.CreateNetwork(network)
  322. }
  323. }
  324. func createNetDualStack() {
  325. var network models.Network
  326. network.NetID = "skynet6"
  327. network.AddressRange = "10.1.2.0/24"
  328. network.AddressRange6 = "fde6:be04:fa5e:d076::/64"
  329. network.IsIPv4 = "yes"
  330. network.IsIPv6 = "yes"
  331. _, err := logic.GetNetwork("skynet6")
  332. if err != nil {
  333. logic.CreateNetwork(network)
  334. }
  335. }
  336. func createNetHost() {
  337. k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
  338. netHost = models.Host{
  339. ID: uuid.New(),
  340. PublicKey: k.PublicKey(),
  341. HostPass: "password",
  342. OS: "linux",
  343. Name: "nethost",
  344. }
  345. _ = logic.CreateHost(&netHost)
  346. }