Question
I have seen a lot of related SO topics, but none of them provides an efficient solution.
I want to find the k-th smallest element (or the median) in a 2D array [1..M][1..N] where each row is sorted in ascending order and all elements are distinct.
I think there is an O(M log MN) solution, but I have no idea how to implement it. (Median of medians, or partitioning with linear complexity, are possible methods, but I can't get any further than that.)
This is an old Google interview question and can be searched for here.
But now I want a hint or a description of the most efficient algorithm (the fastest one).
Also, I read a paper on here, but I don't understand it.
Update 1: one solution is found here, but only when the dimension is odd.
Answer 1:
To solve this problem, it helps to solve a slightly different one. We want to know the upper/lower bounds in each row for where the overall k'th cutoff is. Then we can go through and verify that the number of things at or below the lower bounds is < k, the number of things at or below the upper bounds is > k, and there is only one value between them.
I've come up with a strategy for doing a binary search in all rows simultaneously for those bounds. Being a binary search it "should" take O(log(n)) passes. Each pass involves O(m) work for a total of O(m log(n)) time. I put "should" in quotes because I don't have a proof that it actually takes O(log(n)) passes. In fact it is possible to be too aggressive in a row, discover from other rows that the pivot chosen was off, and then have to back off. But I believe that it does very little backing off and actually is O(m log(n)).
The strategy is to keep track, in each row, of a lower bound, an upper bound, and a mid. Each pass we break each row into a weighted series of ranges: start to lower, lower to mid, mid to upper, and upper to the end, with the weight being the number of things in the range and the value being the last element in it. We then find the k'th value (by weight) in that data structure, and use that as a pivot for our binary search in each row.
If a pivot winds up out of the range from lower to upper, we correct by widening the interval in the direction that corrects the error.
When we have the correct sequence, we've got an answer.
There are a lot of edge cases, so staring at full code may help.
I also assume that all elements of each row are distinct. If they are not, you can get into endless loops. (Solving that means even more edge cases...)
import random

# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth (k, pairs):
    # This does quickselect for average O(len(pairs)).
    # Median of medians is deterministically the same, but a bit slower
    pivot = pairs[int(random.random() * len(pairs))][0]

    # Which side of our answer is the pivot on?
    weight_under_pivot = 0
    pivot_weight = 0
    for value, weight in pairs:
        if value < pivot:
            weight_under_pivot += weight
        elif value == pivot:
            pivot_weight += weight

    if weight_under_pivot + pivot_weight < k:
        filtered_pairs = []
        for pair in pairs:
            if pivot < pair[0]:
                filtered_pairs.append(pair)
        return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
    elif k <= weight_under_pivot:
        filtered_pairs = []
        for pair in pairs:
            if pair[0] < pivot:
                filtered_pairs.append(pair)
        return weighted_kth (k, filtered_pairs)
    else:
        return pivot
# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix (k, matrix):
    # The strategy is to discover the k'th value, and also discover where
    # that would be in each row.
    #
    # For each row we will track what we think the lower and upper bounds
    # are on where it is. Those bounds start as the start and end and
    # will do a binary search.
    #
    # In each pass we will break each row into ranges from start to lower,
    # lower to mid, mid to upper, and upper to end. Some ranges may be
    # empty. We will then create a weighted list of ranges with the weight
    # being the length, and the value being the end of the list. We find
    # where the k'th spot is in that list, and use that approximate value
    # to refine each range. (There is a chance that a range is wrong, and
    # we will have to deal with that.)
    #
    # We finish when all of the uppers are above our k, all the lowers
    # are below, and the upper/lower gap is more than 1 only when our
    # k'th element is in the middle.

    # Our data structure is simply [row, lower, upper, bound] for each row.
    data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
    is_search = True
    while is_search:
        pairs = []
        for row, lower, upper, bound in data:
            # Literal edge cases
            if 0 == upper:
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            elif lower == bound:
                pairs.append((row[lower], lower + 1))
            elif lower + 1 == upper: # No mid.
                pairs.append((row[lower], lower + 1))
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            else:
                mid = (upper + lower) // 2
                pairs.append((row[lower], lower + 1))
                pairs.append((row[mid], mid - lower))
                pairs.append((row[upper], upper - mid))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))

        pivot = weighted_kth(k, pairs)

        # Now that we have our pivot, we try to adjust our parameters.
        # If any adjusts we continue our search.
        is_search = False
        new_data = []
        for row, lower, upper, bound in data:
            # First the cases where our bounds weren't bounds for our pivot.
            # We rebase the interval and either:
            # - double the size of the range
            # - go halfway to the edge
            if 0 < lower and pivot <= row[lower]:
                is_search = True
                if pivot == row[lower]:
                    new_data.append((row, lower-1, min(lower+1, bound), bound))
                elif upper <= lower:
                    new_data.append((row, lower-1, lower, bound))
                else:
                    new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
            elif upper < bound and row[upper] <= pivot:
                is_search = True
                if pivot == row[upper]:
                    new_data.append((row, upper-1, upper+1, bound))
                elif lower < upper:
                    new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
                else:
                    new_data.append((row, upper, upper+1, bound))
            elif lower + 1 < upper:
                if upper == lower+2 and pivot == row[lower+1]:
                    new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
                else:
                    # We will split this interval.
                    is_search = True
                    mid = (upper + lower) // 2
                    if row[mid] < pivot:
                        new_data.append((row, mid, upper, bound))
                    elif pivot < row[mid]:
                        new_data.append((row, lower, mid, bound))
                    else:
                        # We center our interval on the pivot.
                        new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
            else:
                # It looks like we found where the pivot would be in this row.
                new_data.append((row, lower, upper, bound))

        data = new_data # And set up the next search
    return pivot
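To sanity-check the code above, you can compare it against brute force on a small example (the sample matrix and test loop below are mine, not part of the original answer):

# Hypothetical smoke test: compare against sorting the flattened matrix.
matrix = [[1, 4, 9, 16],
          [2, 6, 11, 17],
          [3, 7, 12, 18]]
flat = sorted(v for row in matrix for v in row)
for k in range(1, len(flat) + 1):
    assert kth_in_row_sorted_matrix(k, matrix) == flat[k - 1]
print(kth_in_row_sorted_matrix(5, matrix))  # -> 6, the 5th smallest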
Answer 2:
Another answer has been added to provide an actual solution. This one has been left as it was due to quite the rabbit hole in the comments.
I believe the fastest solution for this is the k-way merge algorithm. It is an O(N log K) algorithm to merge K sorted lists with a total of N items into a single sorted list of size N.
https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge
Given an MxN list, this ends up being O(MN log(M)). However, that is for sorting the entire list. Since you only need the first K smallest items instead of all N*M, the performance is O(K log(M)). This is quite a bit better than what you are looking for, assuming O(K) <= O(M).
Though this assumes you have N sorted lists of size M. If you actually have M sorted lists of size N, this can easily be handled just by changing how you loop over the data (see the pseudocode below), though it does mean the performance is O(K log(N)) instead.
A k-way merge just adds the first item of each list to a heap or other data structure with an O(log N) insert and an O(log N) find-min.
Pseudocode for k-way merge looks a bit like this:
1. For each sorted list, insert the first value into the data structure with some means of determining which list the value came from. I.e. you might insert [value, row_index, col_index] into the data structure instead of just value. This also lets you easily handle looping over either columns or rows.
2. Remove the lowest value from the data structure and append it to the sorted list.
3. Given that the item in step #2 came from list I, add the next lowest value from list I to the data structure. I.e. if the value was row 5 col 4 (data[5][4]): if you are using rows as lists, then the next value would be row 5 col 5 (data[5][5]); if you are using columns, then the next value is row 6 col 4 (data[6][4]). Insert this next value into the data structure as in step #1 (i.e. [value, row_index, col_index]).
4. Go back to step 2 as needed.
For your needs, do steps 2-4 K times.
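A minimal sketch of these steps in Python, using the standard library's heapq as the data structure and treating rows as the sorted lists (the function name and sample call are my own, not from the original answer):

import heapq

def kth_smallest_by_merge(matrix, k):
    # Step 1: seed the heap with the first value of each row, tagged with
    # (row_index, col_index) so we know which list each value came from.
    heap = [(row[0], r, 0) for r, row in enumerate(matrix) if row]
    heapq.heapify(heap)
    # Steps 2-4, repeated K times: pop the lowest value, then push the
    # next value from the same row, if any. Assumes 1 <= k <= M*N.
    for _ in range(k):
        value, r, c = heapq.heappop(heap)
        if c + 1 < len(matrix[r]):
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    return value

# e.g. kth_smallest_by_merge([[1, 3, 5], [2, 4, 6]], 4) == 4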
Answer 3:
Seems like the best way to go is a k-way merge in increasingly larger blocks. A k-way merge seeks to build a sorted list, but we don't need the list sorted and we don't need to consider each element. Instead, we'll create semi-sorted intervals. The intervals will be sorted, but only by their highest value.
https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge
We use the same approach as a k-way merge, but with a twist. Basically it aims to indirectly build a semi-sorted sublist. For example, instead of finding [1,2,3,4,5,6,7,8,10] to determine K=10, it will instead find something like [(1,3),(4,6),(7,15)]. With a k-way merge we consider one item at a time from each list. In this approach, however, when pulling from a given list we want to first consider Z items, then 2 * Z items, then 2 * 2 * Z items, so 2^i * Z items the i-th time. Given an MxN matrix, that means it will require us to pull up to O(log(N)) items from the list, M times.
1. For each sorted list, insert the first K sublists into the data structure, with some means of determining which list the value came from. We want the data structure to use the highest value in the sublist we insert into it. In this case we would want something like [max_value of sublist, row index, start_index, end_index]. O(m)
2. Remove the lowest value (this is now a list of values) from the data structure and append it to the semi-sorted list. O(log m)
3. Given that the item in step #2 came from list I, add the next 2^i * Z values from list I to the data structure upon the i-th time pulling from that specific list (basically just double the number that was present in the sublist just removed from the data structure). O(log m)
4. If the size of the semi-sorted sublist is greater than K, use binary search to find the k-th value, O(log N). If there are any sublists remaining in the data structure where the min value is less than the k-th value, go to step 1 with those lists as inputs and the new K being k - (size of the semi-sorted list).
5. If the size of the semi-sorted sublist is equal to K, return the last value in the semi-sorted sublist; this is the K-th value.
6. If the size of the semi-sorted sublist is less than K, go back to step 2.
As for performance, let's see:
- It takes O(m log m) to add the initial values to the data structure.
- It needs to consider at most O(m) sublists, each requiring O(log n) time, for O(m log n).
- It needs to perform a binary search at the end, O(log m); it may need to reduce the problem into recursive sublists if there is uncertainty about what the value of K is (step 4), but I don't think that'll affect the big O. Edit: I believe this just adds another O(m log(n)) in the worst case, which has no effect on the big O.
So it looks like it's O(m log(m) + m log(n)), or simply O(m log(mn)).
As an optimization, if K is above NM/2, consider the max value when you would consider the min value, and the min value when you would consider the max value. This will greatly improve performance when K is close to NM.
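Below is a rough Python sketch of this scheme (my own reconstruction, not code from this answer): a heap keyed on each block's max value, blocks doubling in size per pull, and a final verification pass standing in for the recursion in step 4. All names are hypothetical.

import heapq

def kth_smallest_block_merge(matrix, k, z=1):
    # Assumes rows sorted ascending, all elements distinct,
    # and 1 <= k <= total element count.
    pos = [0] * len(matrix)   # first index of each row not yet pushed
    size = [z] * len(matrix)  # size of the next block to push, per row
    heap = []                 # entries: (block max, row, start, end)

    def push(r):
        if pos[r] < len(matrix[r]):
            start = pos[r]
            end = min(start + size[r], len(matrix[r]))
            heapq.heappush(heap, (matrix[r][end - 1], r, start, end))
            pos[r] = end
            size[r] *= 2      # the next pull from this row doubles

    for r in range(len(matrix)):
        push(r)

    # Pop blocks in order of their max until we hold at least k values.
    collected = []
    while len(collected) < k:
        _, r, start, end = heapq.heappop(heap)
        collected.extend(matrix[r][start:end])
        push(r)

    # Verification (step 4): a block still in the heap whose smallest
    # element is below the candidate may hide values we must count.
    while True:
        collected.sort()
        candidate = collected[k - 1]
        bad = [e for e in heap if matrix[e[1]][e[2]] < candidate]
        if not bad:
            return candidate
        for e in bad:
            heap.remove(e)
            _, r, start, end = e
            collected.extend(matrix[r][start:end])
            push(r)
        heapq.heapify(heap)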
Answer 4:
The answers by btilly and Nuclearman provide two different approaches, a kind of binary search and a k-way merge of the rows.
My proposal is to combine both methods.
If k is small enough (let's say, less than 2 or 3 times M) or big enough (by symmetry, close to N x M), find the kth element with an M-way merge of the rows. Of course, we shouldn't merge all the elements, just the first k.
Otherwise, start by inspecting the first and the last column of the matrix in order to find the minimum (which is in the first column) and the maximum (in the last column) values.
Estimate a first pivot value as a linear combination of those two values, something like pivot = min + k * (max - min) / (N * M).
Perform a binary search in each row to determine the last element (the closest one) not greater than the pivot. The number of elements less than or equal to the pivot follows immediately. Comparing the sum of those counts with k tells us if the chosen pivot value is too big or too small, and lets us modify it accordingly. Keep track of the maximum value across all the rows: it may be the kth element, or it may just be used to evaluate the next pivot. If we consider said sum as a function of the pivot, the numeric problem is now to find a zero of sum(pivot) - k, which is a monotonic (discrete) function. At worst, we can use the bisection method (logarithmic complexity) or the secant method.
We can ideally partition each row into three ranges:
- At the left, the elements which are surely less than or equal to the kth element.
- In the middle, the undetermined range.
- At the right, the elements which are surely greater than the kth element.
The undetermined range will shrink at every iteration, eventually becoming empty for most rows. At some point, the number of elements still in the undetermined ranges, scattered throughout the matrix, will be small enough to resort to a single M-way merge of those ranges.
If we consider the time complexity of a single iteration as O(M log N), or M binary searches, we need to multiply it by the number of iterations required for the pivot to converge to the value of the kth element, which could be O(log NM). This sums up to O(M log N log M), or O(M log N log N) if N > M.
Note that, if the algorithm is used to find the median, with the M-way merge as the last step it is easy to find the (k + 1)th element too.
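A minimal sketch of just the counting machinery, assuming integer entries and plain bisection on the value range (so without the interpolated pivot, the per-row range tracking, or the final M-way merge; the function name is mine):

import bisect

def kth_by_value_bisection(matrix, k):
    # Bisect on the value range, counting elements <= pivot with one
    # binary search per row: O(M log N) per iteration.
    # Assumes integer entries, rows sorted ascending, 1 <= k <= M*N.
    lo = min(row[0] for row in matrix)
    hi = max(row[-1] for row in matrix)
    while lo < hi:
        pivot = (lo + hi) // 2
        # sum(pivot) from the text: how many elements are <= pivot
        if sum(bisect.bisect_right(row, pivot) for row in matrix) < k:
            lo = pivot + 1
        else:
            hi = pivot
    return lo  # smallest value v with count(<= v) >= k, i.e. the kth element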
Answer 5:
Maybe I am missing something, but if your NxM matrix A has M rows already sorted ascending with no repetition of elements, then the k-th smallest value of a row is just the k-th element of that row, which is O(1). To move to 2D you just select the k-th column instead, sort it ascending in O(M.log(M)), and again pick the k-th element:
- let's have matrix A[N][M] where elements are A[column][row]
- sort the k-th column of A ascending, O(M.log(M)), i.e. sort A[k][i] where i = { 1,2,3,...M } ascending
- pick A[k][k] as the result
In case you want the k-th smallest of all the elements in A instead, then you need to exploit the already sorted rows, in a form similar to merge sort:
- create an empty list c[] for holding the k smallest values
- process the columns:
  - create a temp array b[] which holds the processed column quick-sorted ascending, O(N.log(N))
  - merge c[] and b[] so that c[] holds up to the k smallest values; using a temp array d[] leads to O(k+n)
  - if no item from b was used during the merge, stop processing columns; this can be done by adding a flag array f[] which holds whether each value came from b or c during the merge, and then just checking if any value was taken from b
- output c[k-1]
When put all together, the final complexity is O(min(k,M).N.log(N)). If we consider that k is less than M, we can rewrite it as O(k.N.log(N)), otherwise O(M.N.log(N)). Also, on average the number of columns to iterate will be even less, more likely ~(1+(k/N)), so the average complexity would be ~O(N.log(N)), but that is just my wild guess which might be wrong.
Here is a small C++/VCL example:
//$$---- Form CPP ----
//---------------------------------------------------------------------------
#include <vcl.h>
#pragma hdrstop
#include "Unit1.h"
#include "sorts.h"
//---------------------------------------------------------------------------
#pragma package(smart_init)
#pragma resource "*.dfm"
TForm1 *Form1;
//---------------------------------------------------------------------------
const int m=10,n=8; int a[m][n],a0[m][n]; // a[col][row]
//---------------------------------------------------------------------------
void generate()
{
    int i,j,k,ii,jj,d=13,b[m];
    Randomize();
    RandSeed=0x12345678;
    // a,a0 = some distinct pseudorandom values (fully ordered asc)
    for (k=Random(d),j=0;j<n;j++)
        for (i=0;i<m;i++,k+=Random(d)+1)
            { a0[i][j]=k; a[i][j]=k; }
    // shuffle a
    for (j=0;j<n;j++)
        for (i=0;i<m;i++)
        {
            ii=Random(m);
            jj=Random(n);
            k=a[i][j]; a[i][j]=a[ii][jj]; a[ii][jj]=k;
        }
    // sort rows asc
    for (j=0;j<n;j++)
    {
        for (i=0;i<m;i++) b[i]=a[i][j];
        sort_asc_quick(b,m);
        for (i=0;i<m;i++) a[i][j]=b[i];
    }
}
//---------------------------------------------------------------------------
int kmin(int k) // k-th min from a[m][n] where a rows are already sorted
{
    int i,j,bi,ci,di,b[n],*c,*d,*e,*f,cn;
    c=new int[k+k+k]; d=c+k; f=d+k;
    // handle edge cases
    if (m<1) return -1;
    if (k>m*n) return -1;
    if (m==1) return a[0][k-1];
    // process columns
    for (cn=0,i=0;i<m;i++)
    {
        // b[] = sorted_asc a[i][]
        for (j=0;j<n;j++) b[j]=a[i][j];         // O(n)
        sort_asc_quick(b,n);                    // O(n.log(n))
        // c[] = c[] + b[] asc sorted and limited to k size
        for (bi=0,ci=0,di=0;;)                  // O(k+n)
        {
            if ((ci>=cn)&&(bi>=n)) break;
            else if (ci>=cn)      { d[di]=b[bi]; f[di]=1; bi++; di++; }
            else if (bi>= n)      { d[di]=c[ci]; f[di]=0; ci++; di++; }
            else if (b[bi]<c[ci]) { d[di]=b[bi]; f[di]=1; bi++; di++; }
            else                  { d[di]=c[ci]; f[di]=0; ci++; di++; }
            if (di>=k) break;                   // keep only the k smallest
        }
        e=c; c=d; d=e; cn=di;
        for (ci=0,j=0;j<cn;j++) ci|=f[j];       // O(k)
        if (!ci) break;                         // no b[] item used -> done
    }
    k=c[k-1];
    delete[] c;
    return k;
}
//---------------------------------------------------------------------------
__fastcall TForm1::TForm1(TComponent* Owner):TForm(Owner)
{
    int i,j,k;
    AnsiString txt="";
    generate();
    txt+="a0[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
        for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a0[i][j]);
    txt+="\r\na[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
        for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a[i][j]);
    k=20;
    txt+=AnsiString().sprintf("\r\n%ith smallest from a0 = %4i\r\n",k,a0[(k-1)%m][(k-1)/m]);
    txt+=AnsiString().sprintf("\r\n%ith smallest from a  = %4i\r\n",k,kmin(k));
    mm_log->Lines->Add(txt);
}
//-------------------------------------------------------------------------
Just ignore the VCL stuff. The function generate computes the matrices a0 and a, where a0 is fully sorted and a has only its rows sorted, with all values distinct. The function kmin implements the algorithm described above, returning the k-th smallest value from a[m][n]. For sorting I used this:
template <class T> void sort_asc_quick(T *a,int n)
{
    int i,j; T a0,a1,p;
    if (n<=1) return;                   // stop recursion
    if (n==2)                           // edge case
    {
        a0=a[0];
        a1=a[1];
        if (a0>a1) { a[0]=a1; a[1]=a0; } // swap if out of order
        return;
    }
    for (a0=a1=a[0],i=0;i<n;i++)        // pivot = middle (should be median)
    {
        p=a[i];
        if (a0>p) a0=p;
        if (a1<p) a1=p;
    }
    if (a0==a1) return;                 // all values are the same, stop
    p=(a0+a1+1)/2;
    if (a0==p) p++;
    for (i=0,j=n-1;i<=j;)               // regroup
    {
        a0=a[i];
        if (a0<p) i++; else { a[i]=a[j]; a[j]=a0; j--; }
    }
    sort_asc_quick(a  ,  i);            // recursion a[] <  p
    sort_asc_quick(a+i,n-i);            // recursion a[] >= p
}
And here is the output:
a0[][]
10 17 29 42 54 66 74 85 90 102
112 114 123 129 142 145 146 150 157 161
166 176 184 191 195 205 213 216 222 224
226 237 245 252 264 273 285 290 291 296
309 317 327 334 336 349 361 370 381 390
397 398 401 411 422 426 435 446 452 462
466 477 484 496 505 515 522 524 525 530
542 545 548 553 555 560 563 576 588 590
a[][]
114 142 176 264 285 317 327 422 435 466
166 336 349 381 452 477 515 530 542 553
157 184 252 273 291 334 446 524 545 563
17 145 150 237 245 290 370 397 484 576
42 129 195 205 216 309 398 411 505 560
10 102 123 213 222 224 226 390 496 555
29 74 85 146 191 361 426 462 525 590
54 66 90 112 161 296 401 522 548 588
20th smallest from a0 = 161
20th smallest from a = 161
This example iterated only 5 columns...
Source: https://stackoverflow.com/questions/64916213/fastest-algorithm-for-kth-smallest-element-or-median-finding-on-2-dimensional