寻找第k小元素
将n个元素的序列A[1...n]按升序排序后,位于第⌈n/2⌉个位置的元素称为该序列的中项,即序列中第⌈n/2⌉小的元素。求中项最直接的方法是先对序列排序,再取出该元素,这个方法需要O(n log n)的时间。
选择算法用于找出序列中第k小的元素。该算法设置一个阈值,当元素个数小于该阈值时,直接排序并取出第k小的元素;否则将n个元素分成⌊n/5⌋组,每组5个元素,如果n不是5的倍数,则排除剩余的元素(剩余元素只是不参与分组,之后划分时仍然参与)。对每组分别排序并取出其中项,即第3个元素。接着递归地求出这些中项所组成序列的中项,记为mm。然后将A中的元素划分成三个数组A1、A2和A3,分别包含小于、等于和大于mm的元素。最后判断第k小的元素落在三个数组中的哪一个:根据判断结果,算法或者直接返回第k小的元素(即mm),或者在A1或A3上递归。
算法:SELECT
输入:n个元素的数组A[1...n]和整数k,1<=k<=n
输出:A中的第k小元素
select(A, low, high, k)
1. p ← high - low + 1
2. if p < 44 then 将A排序 return A[k]
3. 令q = ⌊p/5⌋。将A分成q组,每组5个元素。如果5不整除p,则排除剩余元素
4. 将q组中的每一组单独排序,找出中项。所有中项的集合为M
5. mm ← select(M, 1, q, ⌈q/2⌉) {mm为中项集合的中项}
6. 将A[low...high]分成三组
     A1 = {a | a < mm}
     A2 = {a | a = mm}
     A3 = {a | a > mm}
7. case
     |A1| >= k: return select(A1, 1, |A1|, k)
     |A1|+|A2| >= k: return mm
     |A1|+|A2| < k: return select(A3, 1, |A3|, k-|A1|-|A2|)
8. end case
?
下面是C++实现:
?
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <stack>
#include <stdexcept>
#include <vector>

using std::cout;
using std::endl;
using std::stack;

/// Partitions a[low..high] around the pivot value a[low].
///
/// After the call every element left of the returned index is <= the pivot
/// and every element right of it is > the pivot.
///
/// @param a    array to partition in place
/// @param low  first index of the range (inclusive)
/// @param high last index of the range (inclusive)
/// @return the final index of the pivot
int Split(int* a, int low, int high) {
    const int pivot = a[low];
    int boundary = low;  // last index known to hold a value <= pivot
    for (int j = low + 1; j <= high; ++j) {
        if (a[j] <= pivot) {
            ++boundary;
            if (boundary != j) {
                std::swap(a[boundary], a[j]);
            }
        }
    }
    std::swap(a[low], a[boundary]);
    return boundary;
}

/// Sorts a[low..high] in ascending order with an iterative quicksort.
///
/// Pending sub-ranges are kept on an explicit stack as (low, high) pairs
/// instead of using recursion.
///
/// @param a    array to sort in place
/// @param low  first index of the range (inclusive)
/// @param high last index of the range (inclusive)
void QuickSort(int* a, int low, int high) {
    if (low >= high) {
        return;
    }
    stack<int> range;
    range.push(low);
    range.push(high);
    while (!range.empty()) {
        // Pop in reverse push order: high first, then low.
        high = range.top();
        range.pop();
        low = range.top();
        range.pop();
        const int w = Split(a, low, high);
        if (low < w - 1) {
            range.push(low);
            range.push(w - 1);
        }
        if (high > w + 1) {
            range.push(w + 1);
            range.push(high);
        }
    }
}

/// Returns the k-th smallest element of A[low..high] (median-of-medians SELECT).
///
/// The elements inside A[low..high] are reordered by the call; elements
/// outside that range are never touched.
///
/// @param A    array holding the elements
/// @param low  first index of the range (inclusive)
/// @param high last index of the range (inclusive)
/// @param k    1-based rank within the range, 1 <= k <= high - low + 1
/// @return the k-th smallest value of the range
/// @throws std::out_of_range if A is null, the range is empty, or k is not a
///         valid rank for the range
int select(int* A, int low, int high, int k) {
    // Ranges shorter than this are sorted directly. The textbook algorithm
    // uses 44; the small value makes the recursion visible on small inputs.
    const int kSortThreshold = 6;
    const int kGroupSize = 5;

    const int p = high - low + 1;
    if (!A || low < 0 || p < 1 || k < 1 || k > p) {
        throw std::out_of_range("select: k must satisfy 1 <= k <= high-low+1");
    }

    if (p < kSortThreshold) {
        QuickSort(A, low, high);
        return A[low + k - 1];  // rank k is relative to low, not to index 0
    }

    // Sort each complete group of five and collect its median (3rd element).
    // Leftover elements (p % 5) are skipped here but still partitioned below.
    const int q = p / kGroupSize;
    std::vector<int> medians(q);
    for (int i = 0; i < q; ++i) {
        const int first = low + i * kGroupSize;
        QuickSort(A, first, first + kGroupSize - 1);
        medians[i] = A[first + kGroupSize / 2];
    }

    // Median of the medians, rank ceil(q / 2); q >= 1 because p >= 6.
    const int mm = select(&medians[0], 0, q - 1, (q + 1) / 2);

    // Three-way partition around mm; equal elements only need to be counted.
    std::vector<int> smaller;
    std::vector<int> larger;
    smaller.reserve(p);
    larger.reserve(p);
    int equalCount = 0;
    for (int i = low; i <= high; ++i) {
        if (A[i] < mm) {
            smaller.push_back(A[i]);
        } else if (A[i] == mm) {
            ++equalCount;
        } else {
            larger.push_back(A[i]);
        }
    }

    const int smallerCount = static_cast<int>(smaller.size());
    if (k <= smallerCount) {
        // smaller is non-empty here because k >= 1.
        return select(&smaller[0], 0, smallerCount - 1, k);
    }
    if (k <= smallerCount + equalCount) {
        return mm;
    }
    // larger is non-empty here because k exceeds the other two parts.
    return select(&larger[0], 0, static_cast<int>(larger.size()) - 1,
                  k - smallerCount - equalCount);
}

// Define SELECT_NO_MAIN to compile the functions above into another program
// (for example a test harness) without this demo entry point.
#ifndef SELECT_NO_MAIN
int main(void) {
    int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2,
               13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
    const int n = static_cast<int>(sizeof(a) / sizeof(a[0]));

    // Print the input before selecting: select() reorders the array.
    cout << "序列:\n";
    for (int i = 0; i < n; i++) {
        cout << a[i] << " ";
    }
    cout << endl;

    const int result = select(a, 0, n - 1, 13);
    cout << "的第k小元素为:" << result << endl;

    std::getchar();  // keep the console window open until a key is pressed
    return 0;
}
#endif
// The Java version follows.
?
package select;

import java.util.ArrayList;
import java.util.Arrays;

/**
 * Finds the k-th smallest element of an integer sequence with the
 * median-of-medians SELECT algorithm.
 *
 * The sequence passed to the constructor is copied, so the caller's array is
 * never modified.
 */
public class SelectArray {
    /** Ranges shorter than this are sorted directly (the textbook value is 44). */
    private static final int SORT_THRESHOLD = 6;

    /** Number of elements per group when computing the group medians. */
    private static final int GROUP_SIZE = 5;

    private ArrayList<Integer> array = new ArrayList<Integer>();

    /**
     * Creates a selector over a copy of the given values.
     *
     * @param array the values to select from
     */
    public SelectArray(int[] array) {
        for (int value : array) {
            this.array.add(value);
        }
    }

    /**
     * Returns the k-th smallest element of A[low..high] without modifying A.
     *
     * @param A    array holding the elements
     * @param low  first index of the range (inclusive)
     * @param high last index of the range (inclusive)
     * @param k    1-based rank within the range, 1 <= k <= high - low + 1
     * @return the k-th smallest value of the range
     */
    private int select(int[] A, int low, int high, int k) {
        int p = high - low + 1;

        if (p < SORT_THRESHOLD) {
            // Sort only the requested range: the backing array may be longer
            // than the range (the partition buffers below are over-allocated).
            int[] sorted = Arrays.copyOfRange(A, low, high + 1);
            Arrays.sort(sorted);
            return sorted[k - 1];
        }

        // Median (3rd element) of every complete group of five. Leftover
        // elements are skipped here but still take part in the partition.
        int q = p / GROUP_SIZE;
        int[] M = new int[q];
        for (int i = 0; i < q; i++) {
            int start = low + i * GROUP_SIZE;
            // copyOfRange's end index is exclusive, so this copies 5 elements.
            int[] group = Arrays.copyOfRange(A, start, start + GROUP_SIZE);
            Arrays.sort(group);
            M[i] = group[GROUP_SIZE / 2];
        }

        // Median of the medians, rank ceil(q / 2); never 0 because q >= 1.
        int mm = select(M, 0, q - 1, (q + 1) / 2);

        // Three-way partition around mm; equal elements only need counting.
        int[] A1 = new int[p];
        int[] A3 = new int[p];
        int count1 = 0;
        int count2 = 0;
        int count3 = 0;
        for (int i = low; i <= high; i++) {
            if (A[i] < mm) {
                A1[count1++] = A[i];
            } else if (A[i] == mm) {
                count2++;
            } else {
                A3[count3++] = A[i];
            }
        }

        if (count1 >= k) {
            return select(A1, 0, count1 - 1, k);
        }
        if (count1 + count2 >= k) {
            return mm;
        }
        return select(A3, 0, count3 - 1, k - count1 - count2);
    }

    /**
     * Returns the k-th smallest element of the stored sequence.
     *
     * @param k 1-based rank, 1 <= k <= number of stored elements
     * @return the k-th smallest value
     * @throws IllegalArgumentException if k is outside the valid range
     */
    public int getSelectedElement(int k) {
        if (k < 1 || k > this.array.size()) {
            throw new IllegalArgumentException(
                    "k must satisfy 1 <= k <= " + this.array.size() + ", got " + k);
        }
        int[] A = new int[this.array.size()];
        for (int i = 0; i < A.length; i++) {
            A[i] = this.array.get(i);
        }
        return select(A, 0, A.length - 1, k);
    }

    /**
     * Demo: prints a sample sequence and its 13th smallest element.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2,
                   13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
        SelectArray sa = new SelectArray(a);
        System.out.println("序列:");
        for (int i = 0; i < a.length; i++) {
            System.out.print(a[i] + " ");
        }
        System.out.println();
        System.out.println("的第k小元素为:" + sa.getSelectedElement(13));
    }
}
// The Python version follows.
?
#! /usr/bin/env python # -*- coding:utf-8 -*-from math import ceilclass SelectList: def __init__(self, l): self.array = list() for i in l: self.array.append(i) def select(self, a, low, high, k): result = 0 p = high-low + 1 if p < 6: a.sort() return a[k-1] q = p/5 M = [0] * q for i in range(0, q): t = a[i*5:i*5+5] t.sort() M[i] = t[2] mm = self.select(M, 0, q-1, int(ceil(q/2.0))) a1 = [] a2 = [] a3 = [] count1 = 0 count2 = 0 count3 = 0 for i in a: if i < mm: a1.append(i) count1 += 1 elif i == mm: a2.append(i) count2 += 1 else: a3.append(i) count3 += 1 if count1 >= k: result = self.select(a1, 0, count1-1, k) elif count1+count2 >= k: result = mm elif count1+count2 < k: result = self.select(a3, 0, count3-1, k-count1-count2) return result def getSelectedElement(self, k): return self.select(self.array, 0, len(self.array)-1, k)if __name__ == '__main__': a = [8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7] sl = SelectList(a) print "序列:" for i in a: print i, print print "的第k小元素为:", sl.getSelectedElement(13)?
?