This page looks best with JavaScript enabled

「NOI2004」郁闷的出纳员-Splay

 ·  ✏️ About  1475 words  ·  ☕ 3 mins read · 👀... views

维护一个数列。

现有四种命令,新加入一个数 $k$ ,把每个数加上 $k$ ,把每个数减去 $k$ ,查询第 $k$ 大的数。如果数列中的任意数小于 $min$ ,将它立即删除。并在最后输出总共删去的数的个数 $res$ 。

如果新加入的数 $k$ 的初值小于 $min$ ,它将不会被加入数列。

链接

Luogu P1486

题解

这是一道经典的平衡树的题,被我用来练手Splay。

写完这道题之后我就觉得,我再也不会想用Splay了。debug了一天,简直浑身难受。以后尽量写旋转&非旋Treap吧。

真香 !!!

构建一颗Splay树。需要记录目前已经全体加过或者减过的数,也就是一个相对值。换算来说就是 树外-相对值=树内,树内+相对值=树外 。后面也就不再太多特殊说明。需要添加两个虚的最大和最小节点,也会导致排名计算的一些变化。

  • 插入操作

先判断是否满足插入条件,即此数是否大于 $min$ ,然后减去相对后正常插入,splay 至根节点。

  • 加上一个数

直接更改全局相对值,由于不会出现删数,不会有其他操作。

  • 减去一个数

首先更改全局相对值,再把小于 $min$ 的数删除,简单的来说就是吧第一个大于等于 $min$ 的数 splay 到根上,然后删除左子树,补上左边的最小节点。

如果正好存在值为 $min$ 的节点,就将它直接 splay 到根,完成上述操作;如果不存在,就插入一个值为 $min-1$ 的节点,寻找它的后继,并 splay 到根,完成上述操作。这时统计 $res$ 需要减去我们刚刚加上的节点。

  • 查询第 k 大

直接查,然后 splay 到根。只需要注意我们的数列是从小到大排列的。

代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <cstdio>
#define MAX 0x3f3f3f3f
using namespace std;

inline int qr(){
    int f = 1,s = 0;char ch = getchar();
    while(ch<'0'||ch>'9'){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch>='0'&&ch<='9'){
        s = (s<<3)+(s<<1)+ch-48;
        ch = getchar();
    }
    return f*s;
}

struct splay_t{
    struct node_t{
        int val,size,cnt;
        node_t *son[2],*p;node_t **null,**root;
        //与父亲关系
        inline bool get_p(){return p->son[1] == this;}
        //双向连接
        inline void link(node_t *dst,bool re){p = dst;dst->son[re] = this;}
        //更新size值
        inline void update(){size = son[0]->size + son[1]->size + cnt;}
        //初始化**root和**null
        inline void init(node_t **null,node_t **root){this->null = null,this->root = root;}
        //获取左右节点的大小
        inline int lsize(){return son[0]->size;}int rsize(){return son[1]->size;}
        //寻找节点前驱或者后继
        node_t *uporlow(int tmp){//0前驱,1后继
            splay();
            node_t *t = son[tmp];
            while(t->son[1-tmp] != *null)
                t = t->son[1-tmp]; 
            return t;
        }
        //旋转
        void rotate(){
            bool re = get_p();node_t *rp = p;
            link(rp->p,rp->get_p());
            son[1-re]->link(rp,re);
            rp->link(this,1-re);
            rp->update();update();
            if(p == *null) *root = this; 
        }
        //splay操作
        node_t* splay(node_t *tar = NULL){
            if(this == *null) return this;
            if(tar == NULL) tar = *null;
            while(p!=tar){
                if(p->p == tar) rotate();
                else{
                    if(p->get_p()==get_p()) p->rotate(),rotate();
                    else rotate(),rotate();
                }
            }
            return this;
        }
    };
    int treecnt;
    node_t pool[300000];
    node_t *null,*root,*lb,*rb;//lb是左边的虚拟节点,rb同理
    //初始化
    splay_t(){
        treecnt = 0;
        newnode(null);root = null;
        null->size = 0,null->val = 0;
        lb = insert(-MAX);rb = insert(MAX);
    }
    //新建节点
    void newnode(node_t *&r,int val = 0){
        r = &pool[treecnt++];
        r->val = val;
        r->son[0] = r->son[1] = r->p = null;
        r->cnt = r->size = 1;
        r->init(&null,&root);
    }
    //寻找给定rank的数字
    node_t* find_Kth(int rank){
        node_t *t = root;
        while(t!=null){
            if(rank<t->lsize())
                t = t->son[0];
            else if((rank-=t->lsize())<t->cnt)
                return t->splay();
            else
                rank-=t->cnt,t = t->son[1];
        }
        return null;
    }
    //按值寻找
    node_t *find_by_val(int val){
        node_t *t = root;
        while(t!=null){
            if(val<t->val)
                t = t->son[0];
            else if(val==t->val)
                return t->splay();
            else
                t = t->son[1];
        }
        return null;
    }
    //插入给定值的节点
    node_t* insert(int val){
        node_t **tar = &root,*parent = null;
        while(*tar!=null){
            (*tar)->size++;
            if((*tar)->val == val){
                (*tar)->cnt++;return *tar;
            }
            else{
                parent = *tar;tar = &(*tar)->son[(*tar)->val<val];
            }
        }
        newnode(*tar,val);
        (*tar)->link(parent,parent->val < val);
        return (*tar)->splay();
    }
    //调试用 打印树
    void print(node_t *r = NULL,int depth = 0){
        if(r == NULL) r = root;
        if(r == null) return;
        else{
            print(r->son[0],depth+1);
            for(int i = 0;i<depth;i++) putchar(' ');
            printf("v:%d,size:%d,cnt:%d,son:%d %d,depth:%03d\n",r->val,r->size,r->cnt,r->son[0]!=null,r->son[1]!=null,depth);
            print(r->son[1],depth+1);
        }
    }
};

splay_t x;int n,minn,res = 0,nowadd = 0;

//插入一个数
inline void insert(int val){if(val>=minn) x.insert(val-nowadd);}//注意要减去nowadd 
//统一加工资
inline void add(int val){nowadd+=val;}
//统一减公司顺便裁人
inline void decrease(int val){
    nowadd-=val;
    splay_t::node_t *r = x.find_by_val(minn-nowadd);//注意要减去nowadd 
    if(r!=x.null)
        r->splay(),res+=(x.root->lsize()-1);
    else
        x.insert(minn-nowadd-1)->uporlow(1)->splay(),res+=(x.root->lsize()-2);
    x.lb->link(x.root,0);x.lb->son[1] = x.null;
    x.root->update();
}
//查找工资排名K位的员工的工资
inline int ask(int rank){
    if(rank > x.root->size - 2) return -1;
    return x.find_Kth(x.root->size-rank-1)->val + nowadd;//注意要加上nowadd
}

int main(){
    n = qr();minn = qr();
    for(int i = 0;i<n;i++){
        char op[20];int k;
        scanf("%s",op);k = qr();
        if(op[0] == 'A')      add(k);
        else if(op[0] == 'S') decrease(k);
        else if(op[0] == 'I') insert(k);
        else if(op[0] == 'F') printf("%d\n",ask(k));
        else if(op[0] == 'P') x.print();
    }
    printf("%d\n",res);
    return 0;
}

cqqqwq
WRITTEN BY
cqqqwq
A student in Software Engineering.


Comments