fork(1) download
  1. #include <iostream>
  2. #include <vector>
  3. #include <algorithm>
  4. #include <chrono>
  5. #include <random>
  6. #include <type_traits>
  7.  
  8. // =====================================================================
  9. // Custom Template Radix Sort - $O(N)$ Complexity
  10. // Optimally designed for 32-bit integers (signed and unsigned)
  11. // =====================================================================
  12. template<typename RandomIt>
  13. void custom_radix_sort(RandomIt first, RandomIt last) {
  14. using T = typename std::iterator_traits<RandomIt>::value_type;
  15.  
  16. // We enforce 32-bit integers for this specific high-performance implementation
  17. //static_assert(std::is_integral_v<T> && sizeof(T) == 4,
  18. // "This template is optimized for 32-bit integers.");
  19.  
  20. size_t n = std::distance(first, last);
  21. if (n <= 1) return;
  22.  
  23. // Buffer for scatter phase
  24. std::vector<T> buffer(n);
  25.  
  26. // Raw pointers for maximum memory speed
  27. T* src = &(*first);
  28. T* dst = buffer.data();
  29.  
  30. // 1. One-Pass Histogram Generation (Cache-Friendly Optimization)
  31. // We count the byte frequencies for all 4 bytes in a single pass.
  32. uint32_t counts[4][256] = {0};
  33.  
  34. for (size_t i = 0; i < n; ++i) {
  35. // Cast to unsigned to prevent arithmetic shift (sign extension) bugs
  36. uint32_t val = static_cast<uint32_t>(src[i]);
  37.  
  38. counts[0][val & 0xFF]++;
  39. counts[1][(val >> 8) & 0xFF]++;
  40. counts[2][(val >> 16) & 0xFF]++;
  41.  
  42. // Handle negative numbers correctly by flipping the sign bit
  43. unsigned char c3 = (val >> 24) & 0xFF;
  44. c3 ^= 128;
  45. counts[3][c3]++;
  46. }
  47.  
  48. // 2. Radix Passes (Process each byte)
  49. for (int byte = 0; byte < 4; ++byte) {
  50. // Calculate prefix sums to find the exact array index for each item
  51. uint32_t pos[256];
  52. pos[0] = 0;
  53. for (int i = 1; i < 256; ++i) {
  54. pos[i] = pos[i - 1] + counts[byte][i - 1];
  55. }
  56.  
  57. // Scatter elements into the destination buffer
  58. int shift = byte * 8;
  59. for (size_t i = 0; i < n; ++i) {
  60. uint32_t val = static_cast<uint32_t>(src[i]);
  61. unsigned char c = (val >> shift) & 0xFF;
  62.  
  63. if (byte == 3) c ^= 128; // Flip highest bit for sorting signed ints
  64.  
  65. dst[pos[c]++] = src[i];
  66. }
  67.  
  68. // Swap src and dst pointers.
  69. // Because we do exactly 4 passes, src will cleanly end up pointing
  70. // back to the original array (`first`), requiring zero final copies!
  71. std::swap(src, dst);
  72. }
  73. }
  74.  
  75. // =====================================================================
  76. // Benchmark Engine
  77. // =====================================================================
  78. int main() {
  79. const int SIZE = 1000000;
  80.  
  81. std::cout << "Generating " << SIZE << " random elements...\n";
  82. std::vector<int> data_std(SIZE);
  83.  
  84. // Generate chaotic random data (including negative numbers)
  85. std::mt19937 rng(42);
  86. std::uniform_int_distribution<int> dist(-1000000, 1000000);
  87. for(int i = 0; i < SIZE; ++i) {
  88. data_std[i] = dist(rng);
  89. }
  90.  
  91. // Duplicate data
  92. std::vector<int> data_custom = data_std;
  93.  
  94. // ---------------------------------------------------------
  95. // Benchmark 1: std::sort (O(N log N))
  96. // ---------------------------------------------------------
  97. auto start1 = std::chrono::high_resolution_clock::now();
  98. std::sort(data_std.begin(), data_std.end());
  99. auto end1 = std::chrono::high_resolution_clock::now();
  100. std::chrono::duration<double, std::milli> time_std = end1 - start1;
  101.  
  102. // ---------------------------------------------------------
  103. // Benchmark 2: Custom Radix Sort (O(N))
  104. // ---------------------------------------------------------
  105. auto start2 = std::chrono::high_resolution_clock::now();
  106. custom_radix_sort(data_custom.begin(), data_custom.end());
  107. auto end2 = std::chrono::high_resolution_clock::now();
  108. std::chrono::duration<double, std::milli> time_custom = end2 - start2;
  109.  
  110. // ---------------------------------------------------------
  111. // Results
  112. // ---------------------------------------------------------
  113. std::cout << "\n--- Benchmark Results (" << SIZE << " elements) ---\n";
  114. std::cout << "std::sort time: " << time_std.count() << " ms\n";
  115. std::cout << "custom_radix_sort: " << time_custom.count() << " ms\n";
  116.  
  117. // Verify
  118. bool correct = std::is_sorted(data_custom.begin(), data_custom.end());
  119. std::cout << "\nCustom sort is correct: " << (correct ? "True" : "False") << "\n";
  120.  
  121. // Speed Math
  122. if (time_custom.count() < time_std.count()) {
  123. std::cout << "🏆 RESULT: custom_radix_sort won!\n";
  124. std::cout << "It is " << (time_std.count() / time_custom.count())
  125. << "x faster than std::sort.\n";
  126. }
  127.  
  128. return 0;
  129. }
Success #stdin #stdout 0.1s 14912KB
stdin
Standard input is empty
stdout
Generating 1000000 random elements...

--- Benchmark Results (1000000 elements) ---
std::sort time:          58.2365 ms
custom_radix_sort:       9.32814 ms

Custom sort is correct:  True
🏆 RESULT: custom_radix_sort won!
It is 6.2431x faster than std::sort.