BZOJ 1208：Splay 与 set 两种写法

BZOJ

题意：有一堆人和宠物陆续到来。同一时刻，等待中的要么全是人，要么全是宠物。新来者若与等待方类型不同，就从等待方中选一个特点值与自己相差最小的；若有两个差值相等，选特点值较小的那个。每次的不满意值为两者特点值之差的绝对值。求所有不满意值之和（对 1000000 取模）。


一样是随便乱跑。
又是一次 set 比自己手写的平衡树还快的例子。

Splay 版本（内存 1980K，用时 6460ms）

#include <cstdio>
#include <algorithm>
using namespace std;
// Reads one integer from stdin into x.
// Every character before the first digit is skipped; a '-' among them makes
// the value negative. If EOF is reached before any digit appears, x becomes 0
// instead of the loop spinning forever on EOF.
template<class T> 
inline void read(T& x)
{
    int c = getchar();   // int, not char: EOF must stay distinguishable from data
    T p = 1, n = 0;
    while(c != EOF && (c < '0' || c > '9')){if(c == '-') p = -1; c = getchar();}
    while(c >= '0' && c <= '9'){n = n * 10 + c - '0'; c = getchar();}
    x = p * n;
}
// Convenience overloads: read two / three integers in order.
template<class T, class U>
inline void read(T& x, U& y){read(x), read(y);}
template<class T, class U, class V>
inline void read(T& x, U& y, V& z){read(x), read(y), read(z);}

// Modulus applied to the running answer.
const int mod = 1000000;
// Large sentinel value (0x3f3f3f3f == 1061109567).
const int INF = 0x3f3f3f3f;

// Splay-tree node: an int key plus the cached size of its subtree.
struct Node
{
    int v;          // stored key
    int size;       // number of nodes in this subtree, itself included
    Node *ch[2];    // children: ch[0] is left, ch[1] is right

    Node(int a = 0)
    {
        v = a;
        size = 1;
        ch[0] = ch[1] = NULL;
    }
} *root;   // global tree root; zero-initialised, i.e. NULL, until something is inserted

// Shorthand accessors for Node fields. The argument is not parenthesised
// inside the expansion, so pass a plain variable or another accessor call.
// Note: because `size` is a function-like macro, any later `obj.size()`
// call in the same translation unit would also be rewritten by it.
// Left child; expands to an lvalue, so it can be assigned or bound to Node*&.
#define lc(o) (o -> ch[0])
// Right child; also an lvalue.
#define rc(o) (o -> ch[1])
// Subtree size, treating a NULL subtree as size 0.
#define size(o) (o ? o -> size : 0)
// Key stored at the node.
#define v(o) (o -> v)

// Decides where the node of 1-based in-order rank k lies relative to o:
// -1 if it is o itself, 0 if it is in the left subtree, 1 if in the right.
inline int cmp(Node *o, int k)
{
    int here = size(lc(o)) + 1;   // rank of o within its own subtree
    if (k == here)
        return -1;
    return k > here ? 1 : 0;
}
// Recomputes o's cached subtree size from its children; a NULL o is ignored.
inline void maintain(Node* &o)
{
    if (o != NULL)
        o->size = 1 + size(lc(o)) + size(rc(o));
}
// Single rotation at o. The child on side d^1 is lifted into o's place:
// d == 0 lifts the right child (left rotation), d == 1 lifts the left child.
// o is updated to point at the new subtree root.
inline void rotate(Node* &o, int d)
{
    Node *up = o->ch[d ^ 1];   // child that becomes the new root
    Node *moved = up->ch[d];   // its inner subtree changes parent
    up->ch[d] = o;
    o->ch[d ^ 1] = moved;
    maintain(o);               // o is now below `up`, so refresh it first
    maintain(up);
    o = up;
}

// Splays the node with 1-based in-order rank k up to the root of the subtree
// referenced by o. It works two levels at a time: the target is first splayed
// to a grandchild, then a zig-zig or zig-zag pair of rotations is applied.
// A single rotation is used when the target is a direct child of o.
// k must be a valid rank for this subtree; otherwise a NULL child reaches cmp
// and is dereferenced. An empty subtree is left unchanged.
void splay(Node* &o, int k)
{
    if(!o)
        return;
    int d = cmp(o, k);           // -1: target is o, 0: go left, 1: go right
    if(d == 1)
        k -= size(lc(o)) + 1;    // convert k to a rank inside the right subtree
    if(d != -1)
    {
        Node *p = o -> ch[d];    // child on the path towards the target
        int d2 = cmp(p, k);
        // target's rank inside p's child on side d2 (only used when d2 != -1)
        int k2 = (d2 == 0 ? k : k - size(lc(p)) - 1);
        if(d2 != -1)
        {
            // target is a grandchild's descendant: bring it up to the grandchild first
            splay(p -> ch[d2], k2);
            if(d == d2)
                rotate(o, d ^ 1);        // zig-zig: rotate at o first
            else
                rotate(o -> ch[d], d);   // zig-zag: rotate at the child first
        }
        rotate(o, d ^ 1);                // final rotation lifts the target into o
    }
}


int t1;  // result slot written by getless
// Walks down from o and stores in t1 the largest key that is <= k.
// t1 is not written at all when every key in the subtree is > k.
void getless(Node* o, int k)
{
    Node *cur = o;
    while (cur != NULL)
    {
        if (v(cur) > k)
        {
            cur = lc(cur);   // too large: only the left side can hold candidates
            continue;
        }
        t1 = v(cur);         // best candidate so far; look right for a larger one
        cur = rc(cur);
    }
}

int t2;  // result slot written by getmore
// Walks down from o and stores in t2 the smallest key that is >= k.
// t2 is not written at all when every key in the subtree is < k.
void getmore(Node* o, int k)
{
    for (Node *cur = o; cur != NULL; )
    {
        if (v(cur) < k)
            cur = rc(cur);   // too small: only the right side can hold candidates
        else
        {
            t2 = v(cur);     // best candidate so far; look left for a smaller one
            cur = lc(cur);
        }
    }
}

int siz;  // accumulates the in-order rank of the node insert creates; insert only adds to it, so zero it beforehand
// Inserts key k as a new leaf under o. A key equal to a node's key goes to
// that node's right, so the new node ranks after its equals. Subtree sizes
// are refreshed on the way back up.
void insert(Node* &o, int k)
{
    if (o == NULL)
    {
        o = new Node(k);
        ++siz;                    // count the new node itself
        return;
    }
    if (k < v(o))
        insert(lc(o), k);
    else
    {
        siz += size(lc(o)) + 1;   // o and its whole left subtree precede the new node
        insert(rc(o), k);
    }
    maintain(o);
}

// Removes the node o points to (o must be non-NULL) and splices its subtrees
// back into o's place, then frees the removed node.
// With two children, the minimum of the right subtree (rank 1) is splayed to
// the top of that subtree. It then has no left child, so the old left subtree
// is attached there.
void del(Node* &o)
{
    Node *dead = o;           // the node being unlinked; released below
    if(!lc(o))
        o = rc(o);
    else if(!rc(o))
        o = lc(o);
    else
    {
        splay(rc(o), 1);      // bring the right subtree's minimum to its top
        lc(rc(o)) = lc(o);    // that minimum has an empty left slot: hang the left subtree there
        o = rc(o);
    }
    delete dead;              // release the unlinked node instead of leaking it
    maintain(o);
}

// Deletes one node whose key equals k (the first one met on the search path)
// from the subtree referenced by o, and refreshes cached sizes on the way up.
// If no node has key k, the tree is left unchanged instead of dereferencing NULL.
void remove(Node* &o, int k)
{
    if(!o)
        return;               // k not present: nothing to delete
    if(v(o) == k)
        del(o);
    else if(v(o) < k)
        remove(rc(o), k);
    else
        remove(lc(o), k);
    maintain(o);
}

int main()
{
    int n;
    read(n);
    int ans = 0, last;
    root = NULL;
    for(int i = 0; i < n; i++)
    {
        int a, b;
        read(a, b);
        siz = 0;

        if(!root)
        {
            last = a, insert(root, b);
            splay(root, siz);
        }
        else if(last == a)
        {
            insert(root, b);
            splay(root, siz);
        }
        else
        {
            t1 = -INF, t2 = INF;
            getless(root, b);
            getmore(root, b);
            if(b - t1 > t2 - b)
                ans += t2 - b, ans %= mod, remove(root, t2);
            else
                ans += b - t1, ans %= mod, remove(root, t1);
        }
    }
    printf("%d\n", ans);
    return 0;
}

set 版本（内存 956K，用时 152ms）

#include <cstdio>
#include <set>
#include <algorithm>
using namespace std;
// Reads one integer from stdin into x.
// Every character before the first digit is skipped; a '-' among them makes
// the value negative. If EOF is reached before any digit appears, x becomes 0
// instead of the loop spinning forever on EOF.
template<class T> 
inline void read(T& x)
{
    int c = getchar();   // int, not char: EOF must stay distinguishable from data
    T p = 1, n = 0;
    while(c != EOF && (c < '0' || c > '9')){if(c == '-') p = -1; c = getchar();}
    while(c >= '0' && c <= '9'){n = n * 10 + c - '0'; c = getchar();}
    x = p * n;
}
// Convenience overloads: read two / three integers in order.
template<class T, class U>
inline void read(T& x, U& y){read(x), read(y);}
template<class T, class U, class V>
inline void read(T& x, U& y, V& z){read(x), read(y), read(z);}
const int INF = 0x3f3f3f3f, mod = 1000000;
set<int> st;
int main()
{
    st.insert(INF);
    st.insert(-INF);
    int n, last = -1, ans = 0;;
    read(n);
    while(n--)
    {
        int a, b;
        read(a, b);
        if(st.size() == 2)
            last = a, st.insert(b);
        else if(last == a)
            st.insert(b);
        else
        {
            set<int>::iterator l = --st.lower_bound(b), r = st.lower_bound(b);
            if(b - *l > *r - b)
                ans += *r - b, st.erase(r), ans %= mod;
            else
                ans += b - *l, st.erase(l), ans %= mod;
        }
    }
    printf("%d\n", ans);
    return 0;
}

标签: none

添加新评论