数组中的第 K 个最大元素

题目要求

给定一个整数数组 nums 和一个整数 k,要求编写一个算法来找出并返回这个数组中第 k 个最大的元素。这里的第 k 个最大元素是指在数组完全排序后处于倒数第 k 个位置的元素。

需要注意的是,题目要求返回的是第 k 个最大的元素,而不是第 k 个不同的元素。此外,题目强调算法的时间复杂度必须为 O(n),这意味着不能直接对数组进行排序,因为常规的排序算法如快速排序、归并排序的平均时间复杂度为 O(n log n)。

解题思路

要在 O(n) 的时间复杂度内解决这个问题,可以采用以下几种思路:

  1. 快速选择算法(Quick Select):这是快速排序算法的变种,用于在未完全排序的数组中查找第 k 个最小(或最大)元素。算法的基本思想是随机选择一个“枢轴”元素,然后将数组分为两部分:一部分包含小于枢轴的元素,另一部分包含大于枢轴的元素。这时可以确定枢轴元素的确切位置,如果这个位置恰好是我们要找的第 k 个最大元素的位置,我们就找到了答案。如果不是,我们可以递归地在较小或较大的部分中继续查找。由于每次可以排除一半的元素,所以平均时间复杂度为 O(n)。

  2. 堆(Heap):可以使用最小堆来解决这个问题。首先创建一个大小为 k 的最小堆,并将数组 nums 的前 k 个元素添加到堆中。然后遍历数组中剩余的元素,对于每个元素,如果它大于堆顶元素,则将堆顶元素移除并将当前元素添加到堆中。遍历完成后,堆顶元素即为第 k 个最大的元素。虽然堆的插入和删除操作的时间复杂度是 O(log k),但由于我们只维护一个大小为 k 的堆,并且每个元素最多只被插入和删除一次,因此总的时间复杂度仍然是 O(n)。

  3. 中位数的中位数算法(Median of Medians):这是快速选择算法的一个优化版本,它选择一个好的枢轴来保证算法的最坏情况时间复杂度为 O(n)。它通过选择一组元素的中位数作为枢轴来改进快速选择算法,这样可以保证每次至少有一定比例的元素被排除,从而避免了快速选择算法在最坏情况下退化到 O(n^2) 的风险。

在实际应用中,快速选择算法由于其实现简单和在平均情况下的高效性,通常是解决这类问题的首选方法。而中位数的中位数算法虽然可以保证最坏情况下的性能,但由于其常数因子较大,在实际中使用较少。

Golang 版本

package main

import (
	"fmt"
	"math/rand"
	"time"
)

func findKthLargest(nums []int, k int) int {
	target := len(nums) - k
	left, right := 0, len(nums)-1
	rand.Seed(time.Now().UnixNano())

	for left <= right {
		pivotIndex := rand.Intn(right-left+1) + left
		newPivotIndex := partition(nums, left, right, pivotIndex)

		if newPivotIndex == target {
			return nums[newPivotIndex]
		} else if newPivotIndex < target {
			left = newPivotIndex + 1
		} else {
			right = newPivotIndex - 1
		}
	}

	return -1
}

func partition(nums []int, left, right, pivotIndex int) int {
	pivotValue := nums[pivotIndex]
	nums[pivotIndex], nums[right] = nums[right], nums[pivotIndex]
	storeIndex := left

	for i := left; i < right; i++ {
		if nums[i] < pivotValue {
			nums[storeIndex], nums[i] = nums[i], nums[storeIndex]
			storeIndex++
		}
	}

	nums[right], nums[storeIndex] = nums[storeIndex], nums[right]
	return storeIndex
}

func main() {
	nums := []int{3, 2, 1, 5, 6, 4}
	k := 2
	fmt.Printf("The %dth largest element is %d\n", k, findKthLargest(nums, k))
}

这段代码实现了快速选择算法,用于在 O(n) 的平均时间复杂度内找到数组中第 k 个最大的元素。代码首先定义了 findKthLargest 函数,它使用了一个辅助函数 partition 来根据枢轴值对数组进行划分,并递归地在划分的子数组中查找第 k 个最大的元素。主函数 main 提供了一个示例数组和 k 值,并调用 findKthLargest 函数来输出结果。

C++版本

#include <iostream>
#include <vector>
#include <algorithm>
#include <cstdlib>
#include <ctime>

int partition(std::vector<int>& nums, int left, int right) {
    int pivot = nums[left];
    int l = left + 1;
    int r = right;
    while (l <= r) {
        if (nums[l] < pivot && nums[r] > pivot) {
            std::swap(nums[l++], nums[r--]);
        }
        if (nums[l] >= pivot) l++;
        if (nums[r] <= pivot) r--;
    }
    std::swap(nums[left], nums[r]);
    return r;
}

int quickSelect(std::vector<int>& nums, int left, int right, int k) {
    if (left == right) return nums[left];
    srand(time(0));
    int pivotIndex = left + rand() % (right - left + 1);
    pivotIndex = partition(nums, left, right);
    if (k == pivotIndex) {
        return nums[k];
    } else if (k < pivotIndex) {
        return quickSelect(nums, left, pivotIndex - 1, k);
    } else {
        return quickSelect(nums, pivotIndex + 1, right, k);
    }
}

int findKthLargest(std::vector<int>& nums, int k) {
    int size = nums.size();
    return quickSelect(nums, 0, size - 1, size - k);
}

int main() {
    std::vector<int> nums = {3, 2, 1, 5, 6, 4};
    int k = 2;
    std::cout << "The " << k << "th largest element is " << findKthLargest(nums, k) << std::endl;
    return 0;
}

这段 C++ 代码实现了快速选择算法,用于在 O(n) 的平均时间复杂度内找到数组中第 k 个最大的元素。代码中定义了 partition 函数用于根据枢轴值对数组进行划分,quickSelect 函数用于递归地在划分的子数组中查找第 k 个最大的元素,以及 findKthLargest 函数用于启动快速选择过程。主函数 main 提供了一个示例数组和 k 值,并调用 findKthLargest 函数来输出结果。

Python 版本

import random

def partition(nums, left, right):
    pivot = nums[right]
    i = left
    for j in range(left, right):
        if nums[j] < pivot:
            nums[i], nums[j] = nums[j], nums[i]
            i += 1
    nums[i], nums[right] = nums[right], nums[i]
    return i

def quick_select(nums, left, right, k):
    if left == right:
        return nums[left]
    pivot_index = random.randint(left, right)
    pivot_index = partition(nums, left, right)
    if k == pivot_index:
        return nums[k]
    elif k < pivot_index:
        return quick_select(nums, left, pivot_index - 1, k)
    else:
        return quick_select(nums, pivot_index + 1, right, k)

def find_kth_largest(nums, k):
    return quick_select(nums, 0, len(nums) - 1, len(nums) - k)

# Example usage:
if __name__ == "__main__":
    nums = [3, 2, 1, 5, 6, 4]
    k = 2
    print(f"The {k}th largest element is {find_kth_largest(nums, k)}")

这段 Python 代码实现了快速选择算法,用于在 O(n) 的平均时间复杂度内找到数组中第 k 个最大的元素。代码中定义了 partition 函数用于根据枢轴值对数组进行划分,quick_select 函数用于递归地在划分的子数组中查找第 k 个最大的元素,以及 find_kth_largest 函数用于启动快速选择过程。主函数部分提供了一个示例数组和 k 值,并调用 find_kth_largest 函数来输出结果。

Java 版本

import java.util.Random;

public class KthLargestElement {

    public static int findKthLargest(int[] nums, int k) {
        int n = nums.length;
        k = n - k; // convert to kth smallest
        int left = 0;
        int right = n - 1;
        Random random = new Random();

        while (left <= right) {
            int pivotIndex = partition(nums, left, right, random.nextInt(right - left + 1) + left);
            if (pivotIndex < k) {
                left = pivotIndex + 1;
            } else if (pivotIndex > k) {
                right = pivotIndex - 1;
            } else {
                return nums[k];
            }
        }

        throw new IllegalArgumentException("k is out of bounds");
    }

    private static int partition(int[] nums, int left, int right, int pivotIndex) {
        int pivotValue = nums[pivotIndex];
        swap(nums, pivotIndex, right); // Move pivot to end
        int storeIndex = left;

        for (int i = left; i < right; i++) {
            if (nums[i] < pivotValue) {
                swap(nums, storeIndex, i);
                storeIndex++;
            }
        }

        swap(nums, right, storeIndex); // Move pivot to its final place
        return storeIndex;
    }

    private static void swap(int[] nums, int a, int b) {
        int tmp = nums[a];
        nums[a] = nums[b];
        nums[b] = tmp;
    }

    public static void main(String[] args) {
        int[] nums = {3, 2, 1, 5, 6, 4};
        int k = 2;
        System.out.println("The " + k + "th largest element is " + findKthLargest(nums, k));
    }
}

这段 Java 代码实现了快速选择算法,用于在 O(n) 的平均时间复杂度内找到数组中第 k 个最大的元素。代码中定义了 partition 方法用于根据枢轴值对数组进行划分,findKthLargest 方法用于启动快速选择过程。swap 方法用于在数组中交换两个元素的位置。主方法 main 提供了一个示例数组和 k 值,并调用 findKthLargest 方法来输出结果。

Rust 版本

use rand::Rng;
use std::cmp::Ordering;

fn partition(nums: &mut Vec<i32>, left: usize, right: usize, pivot_index: usize) -> usize {
    nums.swap(pivot_index, right);
    let pivot = nums[right];
    let mut store_index = left;
    for i in left..right {
        if nums[i] < pivot {
            nums.swap(store_index, i);
            store_index += 1;
        }
    }
    nums.swap(right, store_index);
    store_index
}

fn quick_select(nums: &mut Vec<i32>, left: usize, right: usize, k: usize) -> i32 {
    if left == right {
        return nums[left];
    }

    let mut rng = rand::thread_rng();
    let pivot_index = rng.gen_range(left..=right);
    let pivot_index = partition(nums, left, right, pivot_index);

    match pivot_index.cmp(&k) {
        Ordering::Equal => nums[k],
        Ordering::Greater => quick_select(nums, left, pivot_index - 1, k),
        Ordering::Less => quick_select(nums, pivot_index + 1, right, k),
    }
}

pub fn find_kth_largest(nums: &mut Vec<i32>, k: i32) -> i32 {
    let size = nums.len();
    quick_select(nums, 0, size - 1, size - k as usize)
}

fn main() {
    let mut nums = vec![3, 2, 1, 5, 6, 4];
    let k = 2;
    println!("The {}th largest element is {}", k, find_kth_largest(&mut nums, k));
}

这段 Rust 代码实现了快速选择算法,用于在 O(n) 的平均时间复杂度内找到数组中第 k 个最大的元素。代码中定义了 partition 函数用于根据枢轴值对数组进行划分,quick_select 函数用于递归地在划分的子数组中查找第 k 个最大的元素,以及 find_kth_largest 函数用于启动快速选择过程。主函数 main 提供了一个示例数组和 k 值,并调用 find_kth_largest 函数来输出结果。

总结

上面的解法采用了快速选择算法,这是一种基于快速排序的选择算法,用于在未完全排序的数组中查找第 k 个最小(或最大)元素的高效算法。其核心思想是:

  1. 选择枢轴(Pivot):随机选择一个元素作为枢轴。
  2. 分区(Partitioning):重新排列数组,使得所有小于枢轴的元素都在其左侧,而所有大于枢轴的元素都在其右侧。枢轴的最终位置就是它如果数组被排序后应该在的位置。
  3. 递归(Recursion):递归地在枢轴的左侧或右侧子数组中查找第 k 个最小(或最大)元素。
    • 如果枢轴的位置恰好是 k,那么它就是我们要找的元素。
    • 如果 k 小于枢轴的位置,我们只需要在左侧子数组中查找。
    • 如果 k 大于枢轴的位置,我们只需要在右侧子数组中查找。

快速选择算法的平均时间复杂度为 O(n),但最坏情况下会退化到 O(n^2)。通过随机选择枢轴可以减少这种最坏情况发生的概率。

在上面的代码实现中,我们定义了以下函数:

  • partition:用于实现分区逻辑。
  • quick_select:用于递归地在数组中查找第 k 个最小(或最大)元素。
  • find_kth_largest:用于调整 k 的值以适应快速选择算法,并开始查找过程。

最后,main 函数或相应的入口点提供了一个示例数组和 k 值,并调用 find_kth_largest 函数来输出第 k 个最大元素的值。