barrier.odin 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package sync
  // A barrier that lets a fixed number of threads synchronize the start of some
  // computation: each thread blocks in barrier_wait until all of them have arrived.
  /*
   * Example:
   *
   *     package example
   *
   *     import "core:fmt"
   *     import "core:sync"
   *     import "core:thread"
   *
   *     barrier := &sync.Barrier{};
   *
   *     main :: proc() {
   *         fmt.println("Start");
   *
   *         THREAD_COUNT :: 4;
   *         threads: [THREAD_COUNT]^thread.Thread;
   *
   *         sync.barrier_init(barrier, THREAD_COUNT);
   *         defer sync.barrier_destroy(barrier);
   *
   *
   *         for _, i in threads {
   *             threads[i] = thread.create_and_start(proc(t: ^thread.Thread) {
   *                 // Every "Getting ready!" line is printed before any
   *                 // "Off their marks they go!" line.
   *                 fmt.println("Getting ready!");
   *                 sync.barrier_wait(barrier);
   *                 fmt.println("Off their marks they go!");
   *             });
   *         }
   *
   *         for t in threads {
   *             thread.destroy(t); // join and free thread
   *         }
   *         fmt.println("Finished");
   *     }
   *
   */
  // Barrier is a reusable rendezvous point for `thread_count` threads.
  // Set it up with barrier_init and tear it down with barrier_destroy.
  // barrier_wait reads and writes the counters below only while holding `mutex`.
  Barrier :: struct {
      mutex:         Blocking_Mutex, // held by barrier_wait for its whole body
      cond:          Condition,      // waiters sleep here until the generation advances
      index:         int,            // arrivals so far in the current generation
      generation_id: int,            // incremented each time the barrier opens
      thread_count:  int,            // arrivals needed to open the barrier
  }
  // barrier_init prepares `b` so that `thread_count` calls to barrier_wait
  // release each generation.
  // The condition variable is bound to the barrier's own mutex, so the mutex is
  // initialised first. Both the arrival counter and the generation start at zero.
  barrier_init :: proc(b: ^Barrier, thread_count: int) {
      blocking_mutex_init(&b.mutex);
      condition_init(&b.cond, &b.mutex);

      b.index, b.generation_id, b.thread_count = 0, 0, thread_count;
  }
  // barrier_destroy tears down the primitives created by barrier_init.
  // They are destroyed in the reverse order of creation. barrier_init bound
  // `b.cond` to `b.mutex`, so the condition is released before the mutex it
  // refers to, and the mutex is never destroyed while the condition still
  // references it.
  // Must not be called while any thread is still inside barrier_wait, which
  // uses both `b.mutex` and `b.cond`.
  barrier_destroy :: proc(b: ^Barrier) {
      condition_destroy(&b.cond);
      blocking_mutex_destroy(&b.mutex);
  }
  // barrier_wait blocks the calling thread until `b.thread_count` threads,
  // counting this one, have called it in the current generation.
  // When the barrier opens it resets its arrival counter, so it can be reused
  // immediately and repeatedly.
  //
  // Returns true for the caller whose arrival completed the count (the
  // "leader") and false for every caller that had to wait. If thread_count <= 0,
  // every caller is a leader and nobody blocks.
  barrier_wait :: proc(b: ^Barrier) -> (is_leader: bool) {
      blocking_mutex_lock(&b.mutex);
      defer blocking_mutex_unlock(&b.mutex);

      arrival_gen := b.generation_id;
      b.index += 1;

      if b.index >= b.thread_count {
          // Last arrival: reset the counter and advance the generation so the
          // sleepers can tell this round is over, then wake them all.
          b.index = 0;
          b.generation_id += 1;
          condition_broadcast(&b.cond);
          return true;
      }

      // Not everyone is here yet, so sleep until the generation moves on.
      // The predicate is re-checked after every wake-up, so spurious wake-ups
      // only cause another wait.
      // NOTE(review): this relies on condition_wait_for releasing b.mutex while
      // asleep and re-acquiring it before returning (the mutex bound in
      // barrier_init). Confirm against the Condition implementation.
      for {
          released := arrival_gen != b.generation_id || b.index >= b.thread_count;
          if released {
              break;
          }
          condition_wait_for(&b.cond);
      }
      return false;
  }