题目

https://www.acwing.com/problem/content/description/787/

算法思想

待排序数组 arr,下标 i 到 j。

我们选择 i 到 j 之间的任意一个数据为 pivot(分区点)。之后开始遍历 i 到 j 之间的数据,小于等于 pivot 的数据放在左边,大于等于 pivot 的数据放在右边。经过一系列步骤之后,数组被分成了两部分,前面一部分都是小于等于 pivot 的,后面一部分是大于等于 pivot。

根据分治、递归的处理思想,我们可以一直处理,直到区间缩小为 1,就说明数据都有序了。

解法一

#include <iostream>
#include <algorithm>

const int N = 100010;

int arr[N];

void quick_sort(int l, int r) {
    if(l >= r) return;
    
    int i = l - 1, j = r + 1, x = arr[l+((r-l)>>1)];
    
    while(i < j) {
        do { ++i; } while(arr[i] < x);
        do { --j; } while(arr[j] > x);
        
        if(i < j) std::swap(arr[i], arr[j]);
    }
    
    quick_sort(l, j);
    quick_sort(j + 1, r);
}

int main() {
    int n;
    std::cin >> n;
    
    for(int i = 0; i < n; ++i)
        std::cin >> arr[i];
    
    quick_sort(0, n - 1);
    
    for(int i = 0; i < n; ++i)
        std::cout << arr[i] << " ";
    
    std::cout << std::endl;
    
    return 0;
}

我们通过样例来模拟一下整个流程:

证明

我们用循环不变式的方式来证明一下算法的正确性,即保证 arr[l...j] <= xarr[j+1...r] >= x

循环不变式:arr[l...i] <= xarr[j...r] >= x

初始情况i = l - 1j = r + 1,因为 arr[l...l-1]arr[r+1...r] 均为空,所以条件成立;

保持:假设某轮循环开始前,循环不变式成立,arr[l...i-1] <= xarr[j+1...r] >= x。执行当前循环体

do { ++i; } while(arr[i] < x);   // 执行之后,arr[i] >= x
do { --j; } while(arr[j] > x);   // 执行之后,arr[j] <= x
if(i < j) swap(arr[i], arr[j]);  // 执行之后,arr[i] <= x,arr[j] >= x

因此,arr[l...i] <= xarr[j...r] >= x

结束

正常情况下,结束时是 arr[l...i] <= xarr[j...r] >= xi >= j,显而易见 arr[l...j] <= xarr[j+1...r] >= x。但是,实际情况是,最后一轮中

// i >= j
if(i < j) swap(arr[i], arr[j])

这句不会执行,因此只能保证 arr[l...i-1] <= xarr[j+1...r] >= x ,同时 arr[i] >= xarr[j] <= xi >= j。由 arr[l...i-1] <= xarr[i] >= xarr[j] <= x 可以得到 arr[l...j] <= x

结论 arr[l...j] <= xarr[j+1...r] >= x

附加

在循环完成后,我们还有一个细节不能忽视,就是需要保证 j >= l, j <= r - 1 ,因为只有这样,才不会陷入无限循环。

证明: 假设 j = r,表明 do { --j; } while(arr[j] > x); 只执行了一次,即 arr[i] >= xi >= j, arr[j] <= x,可以得到 i = j = r,这与 x = arr[l + ((r-l) >> 1)] 矛盾,因此 j <= r - 1

假设 j < l,即 arr[l...r] > x,与实际情况相矛盾,因此可证明 j >= l, j <= r - 1

证明参考:

https://www.acwing.com/solution/content/16777/

解法二

#include <iostream>

const int N = 100010;

int arr[N];

void quick_sort(int l, int r) {
    if(l >= r) return;
    
    int i = l - 1, j = r + 1, x = arr[l+((r-l+1)>>1)];
    
    while(i < j) {
        do { ++i; } while(x > arr[i]);
        do { --j; } while(x < arr[j]);
        
        if(i < j) std::swap(arr[i], arr[j]);
    }
    
    quick_sort(l, i-1);    // 与解法一证明相似
    quick_sort(i, r);
    
    return;
}

int main() {
    int n;
    std::cin >> n;
    
    for(int i = 0; i < n; ++i)
        std::cin >> arr[i];
        
    quick_sort(0, n - 1);
    
    for(int i = 0; i < n; ++i)
        std::cout << arr[i] << " ";
    
    std::cout << std::endl;
    
    return 0;
}

x = arr[l+((r-l+1)>>1)] 这里这么处理的原因是不能让 x = arr[l],因为在这种情况下,当 arr[i+1...r] > x 会导致i = l ,从而陷入无线递归的过程。解法一同理,不能让 x = arr[r]

解法三

#include <iostream>
#include <algorithm>

const int N = 100010;

int arr[N];

int stack[2*N];
int top;

void quick_sort(int l, int r) {
    stack[top++] = l;
    stack[top++] = r;
    
    int i, j, x;
    while(top > 0) {
        r = stack[--top]; l = stack[--top];
        i = l - 1, j = r + 1;
        x = arr[l+((r-l)>>1)];
        while(i < j) {
            do { ++i; } while(arr[i] < x);
            do { --j; } while(arr[j] > x);
            if(i < j)
                std::swap(arr[i], arr[j]);
        }
        
        if(l < j) {
            stack[top++] = l;
            stack[top++] = j;
        }
        
        if(j + 1 < r) {
            stack[top++] = j + 1;
            stack[top++] = r;
        }
    }
}

int main() {
    int n;
    std::cin >> n;
    
    for(int i = 0; i < n; ++i)
        std::cin >> arr[i];
        
    quick_sort(0, n - 1);
    
    for(int i = 0; i < n; ++i)
        std::cout << arr[i] << " ";
    
    std::cout << std::endl;
    
    return 0;
}

非递归实现

解法四

基于颜色分类问题的快速排序

#include <iostream>
#include <algorithm>

const int N = 100010;

int arr[N];

void quick_sort(int l, int r)
{
    if(l >= r) return;
    
    int idx = l + ((r - l) >> 1);
    int x = arr[idx];
    std::swap(arr[idx], arr[r]);
    int i = l - 1, cur = l;
    while(cur < r) {
        if(arr[cur] < x) {
            std::swap(arr[++i], arr[cur++]);
        } else {
            ++cur;
        }
    }
    
    std::swap(arr[i+1], arr[r]);
    
    quick_sort(l, i);
    quick_sort(i+2, r);
}

int main()
{
    int n;
    std::cin >> n;
    
    for(int i = 0; i < n; ++i)
        std::cin >> arr[i];
    
    quick_sort(0, n - 1);
    
    for(int i = 0; i < n; ++i)
        std::cout << arr[i] << " ";
    
    std::cout << std::endl;
}

复杂度

时间复杂度:O(nlogn)

空间复杂度:O(1)