algorithms: export from google3

This commit is contained in:
Corentin Le Molgat
2025-05-16 14:13:06 +02:00
parent b28b0625f9
commit 17498776bf
4 changed files with 208 additions and 76 deletions

View File

@@ -97,6 +97,7 @@ cc_library(
deps = [
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/base",
"@abseil-cpp//absl/base:log_severity",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/numeric:bits",
@@ -118,14 +119,13 @@ cc_test(
"//ortools/base:dump_vars",
"//ortools/base:gmock_main",
"//ortools/base:mathutil",
"//ortools/base:timer",
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/numeric:bits",
"@abseil-cpp//absl/numeric:int128",
"@abseil-cpp//absl/random",
"@abseil-cpp//absl/random:bit_gen_ref",
"@abseil-cpp//absl/random:distributions",
"@abseil-cpp//absl/time",
"@abseil-cpp//absl/types:span",
"@com_google_benchmark//:benchmark",
],

View File

@@ -30,9 +30,6 @@
// But the worst-case performance of RadixSort() is much faster than the
// worst-case performance of std::sort().
// To be sure, you should benchmark your use case.
//
// TODO: it could be even faster than that when the values are in [0..N) for a
// known value N that's significantly lower than the max integer value.
#include <algorithm>
#include <cstddef>
@@ -45,8 +42,10 @@
#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/base/log_severity.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/numeric/bits.h"
#include "absl/types/span.h"
namespace operations_research {
@@ -54,14 +53,24 @@ namespace operations_research {
// Sorts an array of int, double, or other numeric types. Up to ~10x faster than
// std::sort() when size ≥ 8k: go/radix-sort-bench. See file-level comment.
template <typename T>
void RadixSort(absl::Span<T> values);
void RadixSort(
absl::Span<T> values,
// ADVANCED USAGE: if you're sorting nonnegative integers and suspect that
// their values use fewer bits than their full bit width, you may improve
// performance by setting `num_bits` to a lower value, for example
// NumBitsForZeroTo(max_value). It might even be faster to scan the values
// once just to compute that bound, e.g., for non-empty `values`:
// RadixSort(values, NumBitsForZeroTo(*absl::c_max_element(values)));
int num_bits = sizeof(T) * 8);
// Returns a value usable as the `num_bits` argument of RadixSort() when all
// values to sort lie in [0..max_value].
// - For integral T: the number of bits up to and including the highest set bit
//   of `max_value` (0 when max_value == 0). `max_value` must be nonnegative;
//   this is debug-checked.
// - For non-integral T: always the full bit width, sizeof(T) * 8.
template <typename T>
int NumBitsForZeroTo(T max_value);
// ADVANCED USAGE: For power users who know which radix_width or num_passes
// they need, possibly differing from the canonical values used by RadixSort().
template <typename T, int radix_width, int num_passes>
void RadixSortTpl(absl::Span<T> values);
// TODO(user): Support arbitrary types with an int() or other numerical getter.
// TODO(user): Support the user providing already-allocated memory buffers
// for the radix counts and/or for the temporary vector<T> copy.
@@ -240,49 +249,101 @@ void RadixSortTpl(absl::Span<T> values) {
}
}
// TODO(user): Expose an API that takes the "max value" as argument, for
// users who want to take advantage of that knowledge to reduce the number of
// passes.
template <typename T>
void RadixSort(absl::Span<T> values) {
switch (sizeof(T)) {
case 1:
if (values.size() < 300) {
absl::c_sort(values);
int NumBitsForZeroTo(T max_value) {
if constexpr (!std::is_integral_v<T>) {
return sizeof(T) * 8;
} else {
RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/1>(values);
using U = std::make_unsigned_t<T>;
DCHECK_GE(max_value, 0);
return std::numeric_limits<U>::digits - absl::countl_zero<U>(max_value);
}
return;
case 2:
}
// Compile-time flag: true unless NDEBUG is defined. Used to guard debug-only
// validation code with `if constexpr`.
#ifndef NDEBUG
constexpr bool DEBUG_MODE = true;
#else
constexpr bool DEBUG_MODE = false;
#endif
template <typename T>
void RadixSort(absl::Span<T> values, int num_bits) {
// Debug-check that num_bits is valid w.r.t. the values given.
if constexpr (DEBUG_MODE) {
if constexpr (!std::is_integral_v<T>) {
DCHECK_EQ(num_bits, sizeof(T) * 8);
} else if (!values.empty()) {
auto minmax_it = absl::c_minmax_element(values);
const T min_val = *minmax_it.first;
const T max_val = *minmax_it.second;
if (num_bits == 0) {
DCHECK_EQ(max_val, 0);
} else {
using U = std::make_unsigned_t<T>;
// We only shift by num_bits - 1, to avoid potentially shifting by the
// entire bit width, which would be undefined behavior.
DCHECK_LE(static_cast<U>(max_val) >> (num_bits - 1), 1);
DCHECK_LE(static_cast<U>(min_val) >> (num_bits - 1), 1);
}
}
}
// This shortcut here is important to have early, guarded by as few "if"
// branches as possible, for the use case where the array is very small.
// For larger arrays below, the overhead of a few "if" is negligible.
if (values.size() < 300) {
absl::c_sort(values);
return;
}
// TODO(user): More complex decision tree, based on benchmarks. This one
// is already nice, but some cases can surely be optimized.
if (num_bits <= 16) {
if (num_bits <= 8) {
RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/1>(values);
} else {
RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/2>(values);
}
return;
case 4:
if (values.size() < 300) {
absl::c_sort(values);
} else if (values.size() < 1000) {
} else if (num_bits <= 32) { // num_bits ∈ [17..32]
if (values.size() < 1000) {
if (num_bits <= 24) {
RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/3>(values);
} else {
RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/4>(values);
}
} else if (values.size() < 2'500'000) {
if (num_bits <= 22) {
RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/2>(values);
} else {
RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/3>(values);
}
} else {
RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/2>(values);
}
return;
case 8:
} else if (num_bits <= 64) { // num_bits ∈ [33..64]
if (values.size() < 5000) {
absl::c_sort(values);
} else if (values.size() < 1'500'000) {
if (num_bits <= 33) {
RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/3>(values);
} else if (num_bits <= 44) {
RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/4>(values);
} else if (num_bits <= 55) {
RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/5>(values);
} else {
RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/6>(values);
}
} else {
if (num_bits <= 48) {
RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/3>(values);
} else {
RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/4>(values);
}
return;
}
} else {
LOG(DFATAL) << "RadixSort() called with unsupported value type";
absl::c_sort(values);
}
}
} // namespace operations_research

View File

@@ -13,7 +13,6 @@
#include "ortools/algorithms/radix_sort.h"
#include <bit>
#include <cmath>
#include <cstddef>
#include <cstdint>
@@ -25,6 +24,8 @@
#include "absl/algorithm/container.h"
#include "absl/log/log.h"
#include "absl/numeric/bits.h"
#include "absl/numeric/int128.h"
#include "absl/random/bit_gen_ref.h"
#include "absl/random/distributions.h"
#include "absl/random/random.h"
@@ -41,6 +42,28 @@ namespace {
using ::testing::ElementsAre;
using ::testing::IsEmpty;
template <typename T>
class NumBitsForZeroToTest : public ::testing::Test {};
TYPED_TEST_SUITE_P(NumBitsForZeroToTest);
TYPED_TEST_P(NumBitsForZeroToTest, CorrectnessStressTest) {
absl::BitGen rng;
constexpr int kNumTests = 1'000'000;
for (int test = 0; test < kNumTests; ++test) {
const TypeParam max_val = absl::LogUniform<TypeParam>(
rng, 0, std::numeric_limits<TypeParam>::max());
const int num_bits = NumBitsForZeroTo(max_val);
EXPECT_LE(absl::int128{max_val}, absl::int128{1} << num_bits);
}
}
REGISTER_TYPED_TEST_SUITE_P(NumBitsForZeroToTest, CorrectnessStressTest);
using IntTypes = ::testing::Types<int, uint32_t, int64_t, uint64_t, int16_t,
uint16_t, int8_t, uint8_t>;
INSTANTIATE_TYPED_TEST_SUITE_P(My, NumBitsForZeroToTest, IntTypes);
// If T is a floating-point type, ignores min_val / max_val.
template <typename T>
std::vector<T> RandomValues(absl::BitGenRef rng, size_t size,
@@ -103,6 +126,9 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortSmallSizes) {
// Will we use the standard RadixSort() or the RadixSortTpl<>() variant?
const bool use_main_radix_sort = absl::Bernoulli(rng, 0.5);
const bool use_num_bits = std::is_integral_v<TypeParam> &&
use_main_radix_sort && !allow_negative &&
absl::Bernoulli(rng, 0.5);
// We potentially test the "power usage" of calling RadixSortTpl<> with
// radix_width * num_passes < num_bits(TypeParam), when the actual values
@@ -128,7 +154,12 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortSmallSizes) {
int radix_width = -1;
int num_passes = -1;
if (use_main_radix_sort) {
if (use_num_bits) {
RadixSort(absl::MakeSpan(sorted_values),
NumBitsForZeroTo(max_abs_val.value()));
} else {
RadixSort(absl::MakeSpan(sorted_values));
}
} else {
// Draw random (radix_width, num_passes) pairs until we get a valid one.
constexpr int kMaxNumPasses = 8;
@@ -147,8 +178,8 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortSmallSizes) {
absl::c_sort(expected_values);
ASSERT_TRUE(sorted_values == expected_values)
<< DUMP_VARS(test, use_main_radix_sort, radix_width, num_passes, size,
allow_negative, val_bits, max_abs_val, unsorted_values,
sorted_values, expected_values);
allow_negative, use_num_bits, val_bits, max_abs_val,
unsorted_values, sorted_values, expected_values);
}
}
@@ -205,10 +236,20 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortLargeSizes) {
std::vector<TypeParam> values =
RandomValues<TypeParam>(rng, size, allow_negative, /*max_abs_val=*/{});
const bool use_main_radix_sort = absl::Bernoulli(rng, 0.5);
const bool use_num_bits = std::is_integral_v<TypeParam> &&
use_main_radix_sort && !allow_negative &&
absl::Bernoulli(rng, 0.5);
int radix_width = -1;
int num_passes = -1;
if (use_main_radix_sort) {
if (use_num_bits) {
RadixSort(
absl::MakeSpan(values),
NumBitsForZeroTo(size == 0 ? 1 : *absl::c_max_element(values)));
} else {
RadixSort(absl::MakeSpan(values));
}
} else {
radix_width = RandomRadixWidth(rng);
num_passes =
@@ -218,7 +259,7 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortLargeSizes) {
// Contrary to the 'small' stress test, we don't log the data upon failure.
ASSERT_TRUE(absl::c_is_sorted(values))
<< DUMP_VARS(test, use_main_radix_sort, radix_width, num_passes, size,
allow_negative);
allow_negative, use_num_bits);
}
}
@@ -237,13 +278,16 @@ template <typename T>
std::vector<T> SortedValues(size_t size) {
  // Returns `size` consecutive increasing values. For signed T (this includes
  // floating-point types) the sequence starts at -size / 2, computed in T;
  // for unsigned T it starts at 0.
  const T offset = std::is_signed_v<T> ? -static_cast<T>(size) / 2 : T{0};
  std::vector<T> values(size);
  // The loop index is only read, never assigned. `i + offset` is evaluated in
  // the common type of size_t and T, then converted back to T.
  for (size_t i = 0; i < size; ++i) values[i] = static_cast<T>(i + offset);
  return values;
}
// Selects which sorting routine BM_Sort() runs.
enum Algo {
// absl::c_sort().
kStdSort,
// NOTE(review): BM_Sort() has no visible branch for this value; it would hit
// the "Unsupported algo" fallback. Confirm whether it is still needed.
kRadixSort,
// RadixSortTpl<T, radix_width, num_passes>() with the benchmark's template
// parameters.
kRadixSortTpl,
// RadixSort() with num_bits derived from the benchmark's value bound
// (max_abs_val, or the type's maximum when it is unset).
kRadixSortKnownMax,
// RadixSort() with num_bits derived from the maximum element of the input,
// which is looked up right before sorting.
kRadixSortComputeMax,
// RadixSort() with its default num_bits argument.
kRadixSortWorst,
};
enum InputOrder {
@@ -280,9 +324,22 @@ void BM_Sort(benchmark::State& state) {
to_sort = values;
if constexpr (algo == kStdSort) {
absl::c_sort(to_sort);
} else {
} else if constexpr (algo == kRadixSortTpl) {
absl::Span<T> span{to_sort.data(), to_sort.size()};
RadixSortTpl<T, radix_width, num_passes>(span);
} else if constexpr (algo == kRadixSortKnownMax) {
absl::Span<T> span = absl::MakeSpan(to_sort);
RadixSort(span, NumBitsForZeroTo(
max_abs_val.value_or(std::numeric_limits<T>::max())));
} else if constexpr (algo == kRadixSortComputeMax) {
absl::Span<T> span{to_sort.data(), to_sort.size()};
RadixSort(span, NumBitsForZeroTo(
size == 0 ? 1 : *absl::c_max_element(to_sort)));
} else if constexpr (algo == kRadixSortWorst) {
absl::Span<T> span{to_sort.data(), to_sort.size()};
RadixSort(span);
} else {
LOG(DFATAL) << "Unsupported algo: " << algo;
}
benchmark::DoNotOptimize(to_sort);
}
@@ -317,114 +374,127 @@ BENCHMARK(BM_Sort<kStdSort, int, kAlmostSorted, 1, 1>)
->RangeMultiplier(2)
->Range(1, 128 << 10);
BENCHMARK(BM_Sort<kRadixSort, uint32_t, kRandom, /*radix_width=*/8,
BENCHMARK(BM_Sort<kRadixSortTpl, uint32_t, kRandom, /*radix_width=*/8,
/*num_passes=*/4>)
->RangeMultiplier(2)
->Range(16, 2048);
BENCHMARK(BM_Sort<kRadixSort, uint32_t, kRandom, /*radix_width=*/11,
BENCHMARK(BM_Sort<kRadixSortTpl, uint32_t, kRandom, /*radix_width=*/11,
/*num_passes=*/3>)
->RangeMultiplier(2)
->Range(256, 32 << 20);
BENCHMARK(BM_Sort<kRadixSort, uint32_t, kRandom, /*radix_width=*/16,
BENCHMARK(BM_Sort<kRadixSortTpl, uint32_t, kRandom, /*radix_width=*/16,
/*num_passes=*/2>)
->RangeMultiplier(2)
->Range(128 << 10, 32 << 20);
BENCHMARK(BM_Sort<kRadixSort, int, kRandom, /*radix_width=*/8,
BENCHMARK(BM_Sort<kRadixSortTpl, int, kRandom, /*radix_width=*/8,
/*num_passes=*/4>)
->RangeMultiplier(2)
->Range(16, 2048);
BENCHMARK(BM_Sort<kRadixSort, int, kRandom, /*radix_width=*/11,
BENCHMARK(BM_Sort<kRadixSortTpl, int, kRandom, /*radix_width=*/11,
/*num_passes=*/3>)
->RangeMultiplier(2)
->Range(256, 32 << 20);
BENCHMARK(BM_Sort<kRadixSort, int, kRandom, /*radix_width=*/16,
BENCHMARK(BM_Sort<kRadixSortTpl, int, kRandom, /*radix_width=*/16,
/*num_passes=*/2>)
->RangeMultiplier(2)
->Range(128 << 10, 32 << 20);
BENCHMARK(BM_Sort<kRadixSort, float, kRandom, /*radix_width=*/8,
/*num_passes=*/4>)
BENCHMARK(BM_Sort<kRadixSortKnownMax, int, kRandom, /*radix_width=*/16,
/*num_passes=*/2>)
->RangeMultiplier(2)
->Range(16, 2048);
BENCHMARK(BM_Sort<kRadixSort, float, kRandom, /*radix_width=*/11,
/*num_passes=*/3>)
->Range(128 << 10, 32 << 20);
BENCHMARK(BM_Sort<kRadixSortComputeMax, int, kRandom, /*radix_width=*/16,
/*num_passes=*/2>)
->RangeMultiplier(2)
->Range(256, 32 << 20);
BENCHMARK(BM_Sort<kRadixSort, float, kRandom, /*radix_width=*/16,
->Range(128 << 10, 32 << 20);
BENCHMARK(BM_Sort<kRadixSortWorst, int, kRandom, /*radix_width=*/16,
/*num_passes=*/2>)
->RangeMultiplier(2)
->Range(128 << 10, 32 << 20);
BENCHMARK(BM_Sort<kRadixSort, uint64_t, kRandom, /*radix_width=*/11,
BENCHMARK(BM_Sort<kRadixSortTpl, float, kRandom, /*radix_width=*/8,
/*num_passes=*/4>)
->RangeMultiplier(2)
->Range(16, 2048);
BENCHMARK(BM_Sort<kRadixSortTpl, float, kRandom, /*radix_width=*/11,
/*num_passes=*/3>)
->RangeMultiplier(2)
->Range(256, 32 << 20);
BENCHMARK(BM_Sort<kRadixSortTpl, float, kRandom, /*radix_width=*/16,
/*num_passes=*/2>)
->RangeMultiplier(2)
->Range(128 << 10, 32 << 20);
BENCHMARK(BM_Sort<kRadixSortTpl, uint64_t, kRandom, /*radix_width=*/11,
/*num_passes=*/6>)
->RangeMultiplier(2)
->Range(2048, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, uint64_t, kRandom, /*radix_width=*/13,
BENCHMARK(BM_Sort<kRadixSortTpl, uint64_t, kRandom, /*radix_width=*/13,
/*num_passes=*/5>)
->RangeMultiplier(2)
->Range(2048, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, uint64_t, kRandom, /*radix_width=*/16,
BENCHMARK(BM_Sort<kRadixSortTpl, uint64_t, kRandom, /*radix_width=*/16,
/*num_passes=*/4>)
->RangeMultiplier(2)
->Range(128 << 10, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, uint64_t, kRandom, /*radix_width=*/22,
BENCHMARK(BM_Sort<kRadixSortTpl, uint64_t, kRandom, /*radix_width=*/22,
/*num_passes=*/3>)
->RangeMultiplier(2)
->Range(128 << 10, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, int64_t, kRandom, /*radix_width=*/11,
BENCHMARK(BM_Sort<kRadixSortTpl, int64_t, kRandom, /*radix_width=*/11,
/*num_passes=*/6>)
->RangeMultiplier(2)
->Range(2048, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, int64_t, kRandom, /*radix_width=*/13,
BENCHMARK(BM_Sort<kRadixSortTpl, int64_t, kRandom, /*radix_width=*/13,
/*num_passes=*/5>)
->RangeMultiplier(2)
->Range(2048, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, int64_t, kRandom, /*radix_width=*/16,
BENCHMARK(BM_Sort<kRadixSortTpl, int64_t, kRandom, /*radix_width=*/16,
/*num_passes=*/4>)
->RangeMultiplier(2)
->Range(128 << 10, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, int64_t, kRandom, /*radix_width=*/22,
BENCHMARK(BM_Sort<kRadixSortTpl, int64_t, kRandom, /*radix_width=*/22,
/*num_passes=*/3>)
->RangeMultiplier(2)
->Range(128 << 10, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, double, kRandom, /*radix_width=*/11,
BENCHMARK(BM_Sort<kRadixSortTpl, double, kRandom, /*radix_width=*/11,
/*num_passes=*/6>)
->RangeMultiplier(2)
->Range(2048, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, double, kRandom, /*radix_width=*/13,
BENCHMARK(BM_Sort<kRadixSortTpl, double, kRandom, /*radix_width=*/13,
/*num_passes=*/5>)
->RangeMultiplier(2)
->Range(2048, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, double, kRandom, /*radix_width=*/16,
BENCHMARK(BM_Sort<kRadixSortTpl, double, kRandom, /*radix_width=*/16,
/*num_passes=*/4>)
->RangeMultiplier(2)
->Range(128 << 10, 8 << 20)
->Arg(32 << 20)
->Arg(128 << 20);
BENCHMARK(BM_Sort<kRadixSort, double, kRandom, /*radix_width=*/22,
BENCHMARK(BM_Sort<kRadixSortTpl, double, kRandom, /*radix_width=*/22,
/*num_passes=*/3>)
->RangeMultiplier(2)
->Range(128 << 10, 8 << 20)

View File

@@ -62,6 +62,7 @@
#define DUMP_FOR_EACH_N9(F, a, ...) F(a) DUMP_FOR_EACH_N8(F, __VA_ARGS__)
#define DUMP_FOR_EACH_N10(F, a, ...) F(a) DUMP_FOR_EACH_N9(F, __VA_ARGS__)
#define DUMP_FOR_EACH_N11(F, a, ...) F(a) DUMP_FOR_EACH_N10(F, __VA_ARGS__)
#define DUMP_FOR_EACH_N12(F, a, ...) F(a) DUMP_FOR_EACH_N11(F, __VA_ARGS__)
#define DUMP_CONCATENATE(x, y) x##y
#define DUMP_FOR_EACH_(N, F, ...) \
@@ -69,8 +70,8 @@
#define DUMP_NARG(...) DUMP_NARG_(__VA_OPT__(__VA_ARGS__, ) DUMP_RSEQ_N())
#define DUMP_NARG_(...) DUMP_ARG_N(__VA_ARGS__)
#define DUMP_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, N, ...) N
#define DUMP_RSEQ_N() 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
#define DUMP_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, N, ...) N
#define DUMP_RSEQ_N() 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
#define DUMP_FOR_EACH(F, ...) \
DUMP_FOR_EACH_(DUMP_NARG(__VA_ARGS__), F __VA_OPT__(, __VA_ARGS__))