assert.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package test
  2. import (
  3. "fmt"
  4. "net/netip"
  5. "reflect"
  6. "testing"
  7. "time"
  8. "unsafe"
  9. "github.com/stretchr/testify/assert"
  10. )
  11. // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
  12. // There is currently a special case for `time.loc` (as this code traverses into unexported fields)
  13. func AssertDeepCopyEqual(t *testing.T, a any, b any) {
  14. v1 := reflect.ValueOf(a)
  15. v2 := reflect.ValueOf(b)
  16. if !assert.Equal(t, v1.Type(), v2.Type()) {
  17. return
  18. }
  19. traverseDeepCopy(t, v1, v2, v1.Type().String())
  20. }
  21. func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
  22. if v1.Type() == v2.Type() && v1.Type() == reflect.TypeOf(netip.Addr{}) {
  23. // Ignore netip.Addr types since they reuse an interned global value
  24. return false
  25. }
  26. switch v1.Kind() {
  27. case reflect.Array:
  28. for i := 0; i < v1.Len(); i++ {
  29. if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
  30. return false
  31. }
  32. }
  33. return true
  34. case reflect.Slice:
  35. if v1.IsNil() || v2.IsNil() {
  36. return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
  37. }
  38. if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
  39. return false
  40. }
  41. // A slice with cap 0
  42. if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
  43. return false
  44. }
  45. v1c := v1.Cap()
  46. v2c := v2.Cap()
  47. if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
  48. return assert.Fail(t, "", "%s share some underlying memory", name)
  49. }
  50. for i := 0; i < v1.Len(); i++ {
  51. if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
  52. return false
  53. }
  54. }
  55. return true
  56. case reflect.Interface:
  57. if v1.IsNil() || v2.IsNil() {
  58. return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
  59. }
  60. return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
  61. case reflect.Ptr:
  62. local := reflect.ValueOf(time.Local).Pointer()
  63. if local == v1.Pointer() && local == v2.Pointer() {
  64. return true
  65. }
  66. if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
  67. return false
  68. }
  69. return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
  70. case reflect.Struct:
  71. for i, n := 0, v1.NumField(); i < n; i++ {
  72. if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
  73. return false
  74. }
  75. }
  76. return true
  77. case reflect.Map:
  78. if v1.IsNil() || v2.IsNil() {
  79. return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
  80. }
  81. if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
  82. return false
  83. }
  84. if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
  85. return false
  86. }
  87. for _, k := range v1.MapKeys() {
  88. val1 := v1.MapIndex(k)
  89. val2 := v2.MapIndex(k)
  90. if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
  91. return false
  92. }
  93. if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
  94. return false
  95. }
  96. if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
  97. return false
  98. }
  99. }
  100. return true
  101. default:
  102. if v1.CanInterface() && v2.CanInterface() {
  103. return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
  104. }
  105. e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
  106. e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
  107. return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
  108. }
  109. }