network_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. package controller
import (
	"context"
	"os"
	"sync"
	"testing"

	"github.com/google/uuid"
	"github.com/gravitl/netmaker/database"
	"github.com/gravitl/netmaker/logic"
	"github.com/gravitl/netmaker/models"
	"github.com/stretchr/testify/assert"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
  13. type NetworkValidationTestCase struct {
  14. testname string
  15. network models.Network
  16. errMessage string
  17. }
  18. var netHost models.Host
  19. func TestCreateNetwork(t *testing.T) {
  20. initialize()
  21. deleteAllNetworks()
  22. var network models.Network
  23. network.NetID = "skynet"
  24. network.AddressRange = "10.0.0.1/24"
  25. // if tests break - check here (removed displayname)
  26. //network.DisplayName = "mynetwork"
  27. _, err := logic.CreateNetwork(network)
  28. assert.Nil(t, err)
  29. }
  30. func TestGetNetwork(t *testing.T) {
  31. initialize()
  32. createNet()
  33. t.Run("GetExistingNetwork", func(t *testing.T) {
  34. network, err := logic.GetNetwork("skynet")
  35. assert.Nil(t, err)
  36. assert.Equal(t, "skynet", network.NetID)
  37. })
  38. t.Run("GetNonExistantNetwork", func(t *testing.T) {
  39. network, err := logic.GetNetwork("doesnotexist")
  40. assert.EqualError(t, err, "no result found")
  41. assert.Equal(t, "", network.NetID)
  42. })
  43. }
  44. func TestDeleteNetwork(t *testing.T) {
  45. initialize()
  46. createNet()
  47. //create nodes
  48. t.Run("NetworkwithNodes", func(t *testing.T) {
  49. })
  50. t.Run("DeleteExistingNetwork", func(t *testing.T) {
  51. err := logic.DeleteNetwork("skynet")
  52. assert.Nil(t, err)
  53. })
  54. t.Run("NonExistantNetwork", func(t *testing.T) {
  55. err := logic.DeleteNetwork("skynet")
  56. assert.Nil(t, err)
  57. })
  58. }
  59. func TestCreateKey(t *testing.T) {
  60. initialize()
  61. createNet()
  62. keys, _ := logic.GetKeys("skynet")
  63. for _, key := range keys {
  64. logic.DeleteKey(key.Name, "skynet")
  65. }
  66. var accesskey models.AccessKey
  67. var network models.Network
  68. network.NetID = "skynet"
  69. t.Run("NameTooLong", func(t *testing.T) {
  70. network, err := logic.GetNetwork("skynet")
  71. assert.Nil(t, err)
  72. accesskey.Name = "ThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfail"
  73. _, err = logic.CreateAccessKey(accesskey, network)
  74. assert.NotNil(t, err)
  75. assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
  76. })
  77. t.Run("BlankName", func(t *testing.T) {
  78. network, err := logic.GetNetwork("skynet")
  79. assert.Nil(t, err)
  80. accesskey.Name = ""
  81. key, err := logic.CreateAccessKey(accesskey, network)
  82. assert.Nil(t, err)
  83. assert.NotEqual(t, "", key.Name)
  84. })
  85. t.Run("InvalidValue", func(t *testing.T) {
  86. network, err := logic.GetNetwork("skynet")
  87. assert.Nil(t, err)
  88. accesskey.Value = "bad-value"
  89. _, err = logic.CreateAccessKey(accesskey, network)
  90. assert.NotNil(t, err)
  91. assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'alphanum' tag")
  92. })
  93. t.Run("BlankValue", func(t *testing.T) {
  94. network, err := logic.GetNetwork("skynet")
  95. assert.Nil(t, err)
  96. accesskey.Name = "mykey"
  97. accesskey.Value = ""
  98. key, err := logic.CreateAccessKey(accesskey, network)
  99. assert.Nil(t, err)
  100. assert.NotEqual(t, "", key.Value)
  101. assert.Equal(t, accesskey.Name, key.Name)
  102. })
  103. t.Run("ValueTooLong", func(t *testing.T) {
  104. network, err := logic.GetNetwork("skynet")
  105. assert.Nil(t, err)
  106. accesskey.Name = "keyname"
  107. accesskey.Value = "AccessKeyValuethatistoolong"
  108. _, err = logic.CreateAccessKey(accesskey, network)
  109. assert.NotNil(t, err)
  110. assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'max' tag")
  111. })
  112. t.Run("BlankUses", func(t *testing.T) {
  113. network, err := logic.GetNetwork("skynet")
  114. assert.Nil(t, err)
  115. accesskey.Uses = 0
  116. accesskey.Value = ""
  117. key, err := logic.CreateAccessKey(accesskey, network)
  118. assert.Nil(t, err)
  119. assert.Equal(t, 1, key.Uses)
  120. })
  121. t.Run("DuplicateKey", func(t *testing.T) {
  122. network, err := logic.GetNetwork("skynet")
  123. assert.Nil(t, err)
  124. accesskey.Name = "mykey"
  125. _, err = logic.CreateAccessKey(accesskey, network)
  126. assert.NotNil(t, err)
  127. assert.EqualError(t, err, "duplicate AccessKey Name")
  128. })
  129. }
  130. func TestGetKeys(t *testing.T) {
  131. initialize()
  132. deleteAllNetworks()
  133. createNet()
  134. network, err := logic.GetNetwork("skynet")
  135. assert.Nil(t, err)
  136. var key models.AccessKey
  137. key.Name = "mykey"
  138. _, err = logic.CreateAccessKey(key, network)
  139. assert.Nil(t, err)
  140. t.Run("KeyExists", func(t *testing.T) {
  141. keys, err := logic.GetKeys(network.NetID)
  142. assert.Nil(t, err)
  143. assert.NotEqual(t, models.AccessKey{}, keys)
  144. })
  145. t.Run("NonExistantKey", func(t *testing.T) {
  146. err := logic.DeleteKey("mykey", "skynet")
  147. assert.Nil(t, err)
  148. keys, err := logic.GetKeys(network.NetID)
  149. assert.Nil(t, err)
  150. assert.Equal(t, []models.AccessKey(nil), keys)
  151. })
  152. }
  153. func TestDeleteKey(t *testing.T) {
  154. initialize()
  155. createNet()
  156. network, err := logic.GetNetwork("skynet")
  157. assert.Nil(t, err)
  158. var key models.AccessKey
  159. key.Name = "mykey"
  160. _, err = logic.CreateAccessKey(key, network)
  161. assert.Nil(t, err)
  162. t.Run("ExistingKey", func(t *testing.T) {
  163. err := logic.DeleteKey("mykey", "skynet")
  164. assert.Nil(t, err)
  165. })
  166. t.Run("NonExistantKey", func(t *testing.T) {
  167. err := logic.DeleteKey("mykey", "skynet")
  168. assert.NotNil(t, err)
  169. assert.Equal(t, "key mykey does not exist", err.Error())
  170. })
  171. }
  172. func TestSecurityCheck(t *testing.T) {
  173. //these seem to work but not sure it the tests are really testing the functionality
  174. initialize()
  175. os.Setenv("MASTER_KEY", "secretkey")
  176. t.Run("NoNetwork", func(t *testing.T) {
  177. networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey")
  178. assert.Nil(t, err)
  179. t.Log(networks, username)
  180. })
  181. t.Run("WithNetwork", func(t *testing.T) {
  182. networks, username, err := logic.UserPermissions(false, "skynet", "Bearer secretkey")
  183. assert.Nil(t, err)
  184. t.Log(networks, username)
  185. })
  186. t.Run("BadNet", func(t *testing.T) {
  187. t.Skip()
  188. networks, username, err := logic.UserPermissions(false, "badnet", "Bearer secretkey")
  189. assert.NotNil(t, err)
  190. t.Log(err)
  191. t.Log(networks, username)
  192. })
  193. t.Run("BadToken", func(t *testing.T) {
  194. networks, username, err := logic.UserPermissions(false, "skynet", "Bearer badkey")
  195. assert.NotNil(t, err)
  196. t.Log(err)
  197. t.Log(networks, username)
  198. })
  199. }
  200. func TestValidateNetwork(t *testing.T) {
  201. //t.Skip()
  202. //This functions is not called by anyone
  203. //it panics as validation function 'display_name_valid' is not defined
  204. initialize()
  205. //yes := true
  206. //no := false
  207. //deleteNet(t)
  208. //DeleteNetworks
  209. cases := []NetworkValidationTestCase{
  210. {
  211. testname: "InvalidAddress",
  212. network: models.Network{
  213. NetID: "skynet",
  214. AddressRange: "10.0.0.256",
  215. },
  216. errMessage: "Field validation for 'AddressRange' failed on the 'cidrv4' tag",
  217. },
  218. {
  219. testname: "InvalidAddress6",
  220. network: models.Network{
  221. NetID: "skynet1",
  222. AddressRange6: "2607::ffff/130",
  223. },
  224. errMessage: "Field validation for 'AddressRange6' failed on the 'cidrv6' tag",
  225. },
  226. {
  227. testname: "InvalidNetID",
  228. network: models.Network{
  229. NetID: "with spaces",
  230. },
  231. errMessage: "Field validation for 'NetID' failed on the 'netid_valid' tag",
  232. },
  233. {
  234. testname: "NetIDTooLong",
  235. network: models.Network{
  236. NetID: "LongNetIDName",
  237. },
  238. errMessage: "Field validation for 'NetID' failed on the 'max' tag",
  239. },
  240. {
  241. testname: "ListenPortTooLow",
  242. network: models.Network{
  243. NetID: "skynet",
  244. DefaultListenPort: 1023,
  245. },
  246. errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag",
  247. },
  248. {
  249. testname: "ListenPortTooHigh",
  250. network: models.Network{
  251. NetID: "skynet",
  252. DefaultListenPort: 65536,
  253. },
  254. errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag",
  255. },
  256. {
  257. testname: "KeepAliveTooBig",
  258. network: models.Network{
  259. NetID: "skynet",
  260. DefaultKeepalive: 1010,
  261. },
  262. errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag",
  263. },
  264. }
  265. for _, tc := range cases {
  266. t.Run(tc.testname, func(t *testing.T) {
  267. t.Log(tc.testname)
  268. network := models.Network(tc.network)
  269. network.SetDefaults()
  270. err := logic.ValidateNetwork(&network, false)
  271. assert.NotNil(t, err)
  272. assert.Contains(t, err.Error(), tc.errMessage) // test passes if err.Error() contains the expected errMessage.
  273. })
  274. }
  275. }
  276. func TestIpv6Network(t *testing.T) {
  277. //these seem to work but not sure it the tests are really testing the functionality
  278. initialize()
  279. os.Setenv("MASTER_KEY", "secretkey")
  280. deleteAllNetworks()
  281. createNet()
  282. createNetDualStack()
  283. network, err := logic.GetNetwork("skynet6")
  284. t.Run("Test Network Create IPv6", func(t *testing.T) {
  285. assert.Nil(t, err)
  286. assert.Equal(t, network.AddressRange6, "fde6:be04:fa5e:d076::/64")
  287. })
  288. node1 := createNodeWithParams("skynet6", "")
  289. createNetHost()
  290. nodeErr := logic.AssociateNodeToHost(node1, &netHost)
  291. t.Run("Test node on network IPv6", func(t *testing.T) {
  292. assert.Nil(t, nodeErr)
  293. assert.Equal(t, "fde6:be04:fa5e:d076::1", node1.Address6.IP.String())
  294. })
  295. }
  296. func deleteAllNetworks() {
  297. deleteAllNodes()
  298. nets, _ := logic.GetNetworks()
  299. for _, net := range nets {
  300. logic.DeleteNetwork(net.NetID)
  301. }
  302. }
  303. func initialize() {
  304. database.InitializeDatabase()
  305. createAdminUser()
  306. go logic.ManageZombies(context.Background())
  307. }
  308. func createAdminUser() {
  309. logic.CreateAdmin(&models.User{
  310. UserName: "admin",
  311. Password: "password",
  312. IsAdmin: true,
  313. Networks: []string{},
  314. Groups: []string{},
  315. })
  316. }
  317. func createNet() {
  318. var network models.Network
  319. network.NetID = "skynet"
  320. network.AddressRange = "10.0.0.1/24"
  321. _, err := logic.GetNetwork("skynet")
  322. if err != nil {
  323. logic.CreateNetwork(network)
  324. }
  325. }
  326. func createNetDualStack() {
  327. var network models.Network
  328. network.NetID = "skynet6"
  329. network.AddressRange = "10.1.2.0/24"
  330. network.AddressRange6 = "fde6:be04:fa5e:d076::/64"
  331. network.IsIPv4 = "yes"
  332. network.IsIPv6 = "yes"
  333. _, err := logic.GetNetwork("skynet6")
  334. if err != nil {
  335. logic.CreateNetwork(network)
  336. }
  337. }
  338. func createNetHost() {
  339. k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
  340. netHost = models.Host{
  341. ID: uuid.New(),
  342. PublicKey: k.PublicKey(),
  343. HostPass: "password",
  344. OS: "linux",
  345. Name: "nethost",
  346. }
  347. _ = logic.CreateHost(&netHost)
  348. }