#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>
// =====================================================================
// Custom Template Radix Sort - O(N) complexity
// LSD radix sort (8 bits per pass) for integral keys, signed and unsigned
// =====================================================================
// Sorts [first, last) in ascending order.
//
// Each pass is a stable counting-sort scatter on one byte of the key,
// least significant byte first.
// - Element types of up to 4 bytes use a 32-bit working key (4 passes).
// - 8-byte types use a 64-bit working key (8 passes).
// - Signed values are sign-extended into the key, then their top bit is
//   flipped, so negative values order before non-negative ones.
// - Unsigned values are not flipped.
//
// Preconditions: RandomIt must refer to contiguous storage (std::vector,
// std::array, a raw array). The implementation works on raw pointers
// obtained from *first.
// Cost: O(N * passes) time, plus one auxiliary buffer of N elements.
template<typename RandomIt>
void custom_radix_sort(RandomIt first, RandomIt last) {
    using T = typename std::iterator_traits<RandomIt>::value_type;
    static_assert(std::is_integral_v<T>,
                  "custom_radix_sort only supports integral element types.");
    static_assert(sizeof(T) <= 8,
                  "custom_radix_sort supports integral types of at most 64 bits.");

    // Unsigned working key wide enough to hold any T. Everything up to
    // 4 bytes uses 32 bits, so the common int case keeps exactly 4 passes.
    using Key = std::conditional_t<(sizeof(T) > 4), std::uint64_t, std::uint32_t>;
    constexpr int kPasses = static_cast<int>(sizeof(Key));
    static_assert(kPasses % 2 == 0,
                  "An even pass count leaves the result in the caller's range.");

    // Converting a signed value to an unsigned type is modulo 2^N, which
    // sign-extends it. Flipping the top bit then maps two's-complement order
    // onto plain unsigned order. Unsigned T must NOT be flipped.
    constexpr Key kSignFlip = std::is_signed_v<T>
        ? static_cast<Key>(Key{1} << (8 * sizeof(Key) - 1))
        : Key{0};

    const auto n = static_cast<std::size_t>(std::distance(first, last));
    if (n <= 1) return;

    // Scatter target; src and dst ping-pong between the input and this buffer.
    std::vector<T> buffer(n);
    T* src = &(*first);
    T* dst = buffer.data();

    // 1. Build the histogram for every byte position in a single read pass.
    //    size_t counters cannot overflow for any representable n.
    std::size_t counts[kPasses][256] = {};
    for (std::size_t i = 0; i < n; ++i) {
        const Key key = static_cast<Key>(src[i]) ^ kSignFlip;
        for (int pass = 0; pass < kPasses; ++pass) {
            ++counts[pass][(key >> (8 * pass)) & 0xFFu];
        }
    }

    // 2. One stable scatter per byte, least significant first.
    for (int pass = 0; pass < kPasses; ++pass) {
        // An exclusive prefix sum gives each bucket's first output slot.
        std::size_t pos[256];
        std::size_t running = 0;
        for (int bucket = 0; bucket < 256; ++bucket) {
            pos[bucket] = running;
            running += counts[pass][bucket];
        }

        const int shift = 8 * pass;
        for (std::size_t i = 0; i < n; ++i) {
            const Key key = static_cast<Key>(src[i]) ^ kSignFlip;
            dst[pos[(key >> shift) & 0xFFu]++] = src[i];
        }

        // Because kPasses is even, src points back at the caller's range
        // after the last swap, so no final copy is needed.
        std::swap(src, dst);
    }
}
// =====================================================================
// Benchmark Engine
// =====================================================================
// Times custom_radix_sort against std::sort on identical random data and
// checks that both produce the same output.
// Exit code: 0 if the radix sort output matches std::sort's, 1 otherwise.
int main() {
    constexpr int SIZE = 100000;
    std::cout << "Generating " << SIZE << " random elements...\n";
    std::vector<int> data_std(SIZE);
    // A fixed seed makes every run benchmark the same input (including negatives).
    std::mt19937 rng(42);
    std::uniform_int_distribution<int> dist(-1000000, 1000000);
    for (int& value : data_std) {
        value = dist(rng);
    }
    // Identical copy so both algorithms sort exactly the same input.
    std::vector<int> data_custom = data_std;

    // steady_clock is monotonic. high_resolution_clock may alias
    // system_clock and jump if the wall clock is adjusted mid-measurement.
    using Clock = std::chrono::steady_clock;

    // ---------------------------------------------------------
    // Benchmark 1: std::sort (O(N log N))
    // ---------------------------------------------------------
    const auto start1 = Clock::now();
    std::sort(data_std.begin(), data_std.end());
    const auto end1 = Clock::now();
    const std::chrono::duration<double, std::milli> time_std = end1 - start1;

    // ---------------------------------------------------------
    // Benchmark 2: Custom Radix Sort (O(N))
    // ---------------------------------------------------------
    const auto start2 = Clock::now();
    custom_radix_sort(data_custom.begin(), data_custom.end());
    const auto end2 = Clock::now();
    const std::chrono::duration<double, std::milli> time_custom = end2 - start2;

    // ---------------------------------------------------------
    // Results
    // ---------------------------------------------------------
    std::cout << "\n--- Benchmark Results (" << SIZE << " elements) ---\n";
    std::cout << "std::sort time: " << time_std.count() << " ms\n";
    std::cout << "custom_radix_sort: " << time_custom.count() << " ms\n";

    // Verify against the std::sort reference. is_sorted alone would accept
    // output that lost or duplicated elements.
    const bool correct = (data_custom == data_std);
    std::cout << "\nCustom sort is correct: " << (correct ? "True" : "False") << "\n";

    // Speed Math
    if (time_custom.count() < time_std.count()) {
        std::cout << "🏆 RESULT: custom_radix_sort won!\n";
        std::cout << "It is " << (time_std.count() / time_custom.count())
                  << "x faster than std::sort.\n";
    } else {
        std::cout << "RESULT: std::sort was faster on this run.\n";
    }
    return correct ? 0 : 1;
}