Skip to content

14286: 【原4286】轮到你了

题目

题目描述

author: Howell 原OJ链接:https://acm.sjtu.edu.cn/OnlineJudge-old/problem/4286 注意!!本题不允许使用std::sort(和STL中其它排序算法),std::nth_element,std::make_heap和std::priority_queue(以及其它STL中的有序结构),否则成绩无效(将会手动检查代码)。因此,即使你只想要拿到部分分数,也请自己手写一个排序。

题目描述

晋红衣听说在中期反馈的时候,不少同学吐槽作业题量太大,因此她决定制裁这些同学。

她偷偷统计了每个学生吐槽的数量,并且准备制裁吐槽特别多的人。为了学生们的利益,小助教们据理力争,让晋红衣最终妥协:如果她认为吐槽数第$k$小的学生的吐槽数量不太大,则她不会制裁任何学生。否则,她会把下次小作业的题量翻倍。

现在小助教们精疲力尽,没法统计吐槽数第$k$小的学生的吐槽数,因此,决定下次小作业有几道题的重任就轮到你了。

输入描述

读入数据的第一行是两个数字$n$和$k$

读入数据的第二行是三个数字$seed,Ta$和$Tb$,你需要利用它们来得到所有同学的吐槽数量,具体方法是:

const int MAXN = 20000010
int a[MAXN];
int seed, Ta, Tb;
//...
int Generate()
{
     return seed = seed * Ta + Tb;
}
//...
int main()
{
    /*
     * Todo
     * you can input like scanf("%d%d%d", &seed, &Ta, &Tb)
     */
    for (int i = 0;i < n;++i) 
        a[i] = Generate();
    /*
     * Todo
     */
}

这几行代码(不包括注释)应当出现在你提交的代码里。

输出描述

输出一个数字,代表吐槽数第$k$小的学生的吐槽数的值

样例输入

5 3
0 1 10

样例输出

30

样例解释

根据输入第二行的内容,5个学生的吐槽数分别为10,20,30,40,50,因此第三小的数字是30.

数据范围

对于$50\%$的数据,$0<k\leq n\leq 1,000,000$;

对于所有的数据,$0<k\leq n\leq 20,000,000$

HINT

为了求得数列中第$k$小的数,我们可以使用的方法有:

1.排序;

2.分治;

3.二分答案

4.$\cdots$

下面简单介绍分治算法的思路:

假设我们想要求出$a[l],a[l+1],\cdots a[r]$之中第$k$的数小,我们可以先选出这段数列中的一个数$a[m]$,然后重新整理这个数列(记为$b[l],b[l+1]\cdots b[r]$),将所有比$a[m]$小的数放在$a[m]$左边(记作$b[l],b[l+1]\cdots b[l+p-1]$),比$a[m]$小的数放在右边(记作$b[l+p+1],b[l+p+2]\cdots b[r]$)而$b[l+p]=a[m]$,代表$a[m]$是这个数列当中第$p+1$小的数。

这一重新整理的操作的复杂度是$O(n)$。

现在,如果$p>k$,我们可以在$b[l+1],b[l+2]\cdots b[l+p-1]$中寻找第$k$小的数;如果$p<k$,则我们可以在$b[l+p],b[l+p+1]\cdots b[r]$中寻找第$k-p$小的数;当$p=k$时,$b[l+p]=a[m]$即为我们想要找的数。

如果我们选取的$a[m]$总是恰好是这个数列的中位数,则下一步中,我们要处理的数列长度恰好是这一次的一半,因此,此时的复杂度为$O(n)+O(n/2)+\cdots =O(2n)=O(n)$。

(以下选读)一个更优秀的算法是在选取$m$时进行额外的处理。我们可以将当前区间分块,每个块的大小为$5$(因此,会分得$(r-l+1)/5$个块),通过任何一种排序的方法选出一个块的中位数(这一操作的时间是一个常数$k$),这些中位数组成的数列可以继续分块、求中位数。不断进行操作,直到最后只有一个数时,可以证明,这个数是优秀的选择(你可以在考试结束后翻阅《算法导论》来找到这个证明)。

这一串操作里,对一个小分块选取中位数的操作是常数级别的,因此,该处理的复杂度是 $k(n/5+n/5^2+n/5^3+\cdots)=O(n)$,与数列的重排序在同一数量级,不会影响到最终的复杂度(只会在常数上有区别)。

zqy2018's solution

/*
    See the solution at https://github.com/zqy1018/sjtu_oj_solutions/blob/master/solutions/sjtu4286.md
*/
#include <cstdio>
#define INF 2000000000
using namespace std;
typedef long long ll;
int read(){
    int f = 1, x = 0;
    char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') f = -f; c = getchar();}
    while(c >= '0' && c <= '9')x = x * 10 + c - '0', c = getchar();
    return f * x; 
}
int n, a[20000005], K;
int sd, ta, tb;
int gene(){
    return (sd = sd * ta + tb);
}
int BFPRT(int *a, int l, int r, int K);
template <typename T>
void swap(T& p1, T& p2){
    T tmp = p1;
    p1 = p2, p2 = tmp;
}
int get_index(int *a, int l, int r){
    //  [l, r] 插入排序
    for (int i = l + 1; i <= r; ++i){
        int tmp = a[i], pos = i - 1;
        while (pos >= l && tmp < a[pos])
            a[pos + 1] = a[pos], --pos;
        a[pos + 1] = tmp;
    }
    return (l + r) >> 1;
}
int get_pivot(int *a, int l, int r){
    // 每五个获取中位数不停的获取
    if (r - l < 4) 
        return get_index(a, l, r);
    int pos = l - 1; 
    for (int i = l; i + 4 <= r; i += 5) {
        int id = get_index(a, i, i + 4);
        swap(a[++pos], a[id]);
    }
    int len = pos - l + 1;
    return BFPRT(a, l, pos, (len >> 1) + 1);
}
int partition(int *a, int l, int r, int pivot_id){
    // 依照枢纽元划分
    swap(a[pivot_id], a[r]);
    int pos = l;
    for (int i = l; i < r; ++i)
        if (a[i] < a[r]){
            while (pos < i && a[pos] < a[r])
                ++pos;
            swap(a[pos], a[i]), ++pos;
        }
    swap(a[pos], a[r]);
    return pos;
}
int BFPRT(int *a, int l, int r, int K){
    // 主过程和快速选择差不多
    int pivot_id = get_pivot(a, l, r);
    int div_pos = partition(a, l, r, pivot_id);
    if (div_pos - l >= K)
        return BFPRT(a, l, div_pos - 1, K);
    else if (div_pos - l + 1 == K)
        return div_pos;
    else 
        return BFPRT(a, div_pos + 1, r, K - div_pos + l - 1);
}
void init(){
    n = read(), K = read();
    sd = read(), ta = read(), tb = read();
    for (int i = 1; i <= n; ++i)
        a[i] = gene();
}
void solve(){
    printf("%d\n", a[BFPRT(a, 1, n, K)]);
}
int main(){
    init();
    solve();
    return 0;
}