Post

经典算法题:找到数组中的第K大的数

题目介绍

一道面试时频繁被问及的题目,而且还换着花样问。

给定一个整形数组,找到数组中第 K 大的元素。

还是老样子,光看题目没啥意思,还是思考思考这个算法能有什么实际的用处。其实应用就是把这里的整形数组换成是有意义的数据,比如说是一堆学生的成绩,找出排名第 K 的学生,或者是一个医院的体检报告,找出男性身高第 K 高的。

可能固定数组的应用场景固定,但假如说这里是实时的流数据,应用场景可就多了,比如说实时的线上延迟报表,10 万个请求中第 99995 高的延迟具体是多少时间,我们可以由此算出百分位数(percentile),这对我们分析线上整体的情况很有帮助。那把这里的固定的整形数组换成一个实时的数据流,我们又该如何设计我们的算法呢?

思路一:排序

这个题目最直接了当,也是最容易想到的方法就是排序。因为数组元素固定,直接把整个数组排好序,按序索取即可:

1
2
3
4
public int findKth(int[] arr, int k) {
  Arrays.sort(arr);
  return arr[arr.length - k];
}

这个思路清晰明了,没有任何的理解难度。时间复杂度以及空间复杂度都在排序上,拿 Java 来说,内部排序算法主要是归并排序,因而时间复杂度是 O(NlogN),空间复杂度为 O(N),其中 N 为数组中的元素的总个数。

思路二:堆/优先队列

我们在之前的 文章 中介绍了堆排序算法。用堆来解决这个问题的一个方向就是通过数组来构建一个最大堆,然后依次从堆中移除 K 个元素,第 K 个被移除的就是我们要找的第 K 大元素。这里为了简便,我直接使用 Java 内置的 PriorityQueue,背后的思想都是一样的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public int findKth(int[] arr, int k) {
  PriorityQueue<Integer> pq = new PriorityQueue<Integer>(
    Collections.reverseOrder()
  );

  for (int i = 0; i < arr.length; i++) {
    pq.add(arr[i]);
  }

  for (int i = 1; i < k; i++) {
    pq.poll();
  }

  return pq.peek();
}

这里的复杂度分析就比之前排序的要复杂些,首先通过数组构建堆花费 O(N),具体理论推导证明可以参考 文章。在此基础上,我们进行了 K 次的删除最值操作,每次操作的时间是 O(logN),这样一来,最后的时间复杂度就会是 O(N + KlogN),当 K == N 时,这个复杂度还是会退回到 O(NlogN)。咋一看,搞了半天和之前没啥区别啊?现实情况下是因为 K <= N,当 K 取比较小的值的话,这个算法还是比之前的要快不少。当然了,这并没有达到我们的目的。

再说说空间复杂度,上面因为重新创建了一个数据结构,空间复杂度就会是 O(N)。当然,如果是自己手动实现堆,我们是可以通过一个技巧来节省掉这部分空间,具体可以参照 堆排序的文章

从这里的分析可以看到,堆中的元素越多则每次堆操作就会消耗越多的时间。既然如此,另外一个借助堆的思路就是我们只保存 K 个元素到堆中,然后依次遍历数组中剩下的元素,每当遍历到的元素比当前堆中的最小元素要大我们就将其加入,然后将最小的元素移除,这样一来堆的大小始终是 K 个,遍历完整个数组后堆顶的元素就是第 K 大的元素,具体实现和之前大同小异:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public int findKth(int[] arr, int k) {
  PriorityQueue<Integer> pq = new PriorityQueue<Integer>(
    Collections.reverseOrder()
  );

  for (int i = 0; i < k; i++) {
    pq.add(arr[i]);
  }

  for (int i = k; i < arr.length; i++) {
    if (pq.peek() < arr[i]) {
      pq.poll();
      pq.add(arr[i]);
    }
  }

  return pq.peek();
}

首先构建堆花费 O(K),然后每次移除和插入操作的时间复杂度是 O(logK),进行 2(N-K) 次,最后算在一起,最终的时间复杂度就会是 O(K + (N-K)logK),当 K 取比较极端的值的话,这个时间复杂度还是会退回到 O(NlogN)

虽然说第二种算法和第一种在时间复杂度层面并没有太多的不一样,但是第二种算法是可以应用到实时的流数据上的,因为第二种算法并不需要一开始就拿到全部的数据,堆中存放的一直是 “当前流过的最大的 K 个元素”。这也就解释了我们一开始遗留的问题。

思路三:分治(快速选择)

排序算法中一个比较普遍的思路就是分治,例如家喻户晓的归并排序以及快速排序就是分治的范例。大体思路就是将一个数组一分为二,分别对这两部分进行线性时间的操作。因为每次都是对半分,所以我们可以分 logN 次,每次操作都是线性的,最后的时间复杂度就是 O(NlogN)

看到这里你可能已经有点不耐烦了,搞什么,弄来弄去都还是排序。可是这里我们还是可以有优化的地方,排序是要对整个数组进行操作,但是我们的目的是找到第 K 大的元素,当你把一个数组一分为二,这个我们要找的元素只可能在其中一个半区,另外一个半区我们完全可以丢掉不管,也就是说每次分治过后我们后续的操作范围就缩小为之前的一半。那具体的时间复杂度是多少呢?假设说每次我们都能均分一个数组,第一次均分我们的操作时间是 N,第二次均分后我们的操作时间是 N/2,第三次 N/4 以此类推最后的时间复杂度就会是:

\[O(N + \frac{N}{2} + \frac{N}{4} + \frac{N}{8} + ... + 1) = O(N)\]

具体的实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
public int findKth(int[] arr, int k) {
  return findKth(arr, k, 0, arr.length - 1);
}

public int findKth(int[] arr, int k, int start, int end) {
  if (start >= end) {
    return arr[start];
  }

  int pivot = arr[(start+end)/2];

  int i = start, j = end;
  while (i < j) {
    while (i < j && pivot >= arr[i++]) {}
    while (i < j && pivot <= arr[j--]) {}
    swap(arr, i, j);
  }

  if (k == i) {
    return arr[i];
  } else if (k < i) {
    return findKth(arr, start, i - 1, k);
  }

  return findKth(arr, i, end, k);
}

public void swap(int[] arr, int i, int j) {
  int tmp = arr[i];
  arr[i] = arr[j];
  arr[j] = tmp;
}

上面的实现和快速排序很像,不一样的地方在于最后的时候快速排序是需要对左右两边都进行下一步的操作,而这里我们只是基于我们要找的 K 的值来进行选择。

最终我们可以将时间复杂度优化到 O(N),但需要注意的是这里的时间复杂度是平均时间复杂度,能否达到这个时间复杂度很大程度上取决于均分值(pivot)的选择,如果 pivot 选好的话,能够保证数组能够被均匀的切分成两半,那么算法的效率就会越高。在上面的实现中,我们选择的是数组中间的元素,关于如何选择 pivot,后面我们会在快速排序中详细介绍。

你可能会说,怎么是平均啊?这也不能说明它比前面两个算法更快呀?当然了你完全可以设计一个非常极端的情况让上面的算法退化到 O(N^2)。但在实际的运行过程中,这个算法的速度是大大优于前面两个算法的。这就要说到快速排序了,快速排序快速的地方就在于所有的操作都是基于一个连续的数组,并且在这个过程中并没有创建新的数据结构,但这一部分并没有反应到时间复杂度上,具体的内容后面我们会再做介绍

This post is licensed under CC BY 4.0 by the author.