Naomi's notebook

Naomi's notebook

multiset(ARC074D - 3N Numbers)

ちなみに記事のタイトルは「この問題で私が学んだこと」という基準で決めているため、実は解法に必ずしも関係があるとは限りません。

atcoder.jp


まず、数列は以下の三つに分けられる。
①前半となるN個
➁後半となるN個
③抜いたN個
全ての①は必ず全ての②の前にある。適当なそれ以前が①または③、それ以降が➁または③となるような点は、N個め以上2N個め未満の場所である。
この場所を定めた時、その前の区間とその後の区間からそれぞれ③として抜くべき個数も定まり、この時スコアを最大にするための抜き方を考えるのは容易である。(その前からはスコアが小さい順に、後ろからはスコアが大きい順に抜いていけば良い。)

あとはしゃくとり法みたいな感じで常にスコアが最大になるように③を入れ替えていけば良い。
ここを考えるのにとても時間がかかってしまった…

前半の終わりをN-1から2N-1まで動かし、
・新しく前半になった数字が、前半部分を広げる前に①に含まれていたある数字より大きい時、それを①に加え、①で一番小さかったものを追い出す(追い出したあと、①で一番小さいものは加えたやつでは無く二番目に小さかったものになる…と思いきや、結局①に加えた後に一番小さいものを求め直さなきゃいけないので注意)←WAの原因
して前半部分での得点(正)をそれぞれ求め、

後半の終わりを2NからNまで動かし、
・新しく後半になった数字が、後半部分を広げる前に➁に含まれていたある数字より小さい時、それを➁に加え、➁で一番大きかったものを追い出す
して後半部分での得点(負だが簡単のため正にする)を求め、

あとはこれらを組み合わせた時得点が最大になるものを探せば良い。

最後の追い出すところの実装にすごく時間がかかってしまった…logNで値を順序を保ったまま入れたり出したりしなければいけなかったので、二分木を使う羽目になった。もっと良いやり方もありそう。

setだと同じ数字を複数入れられないのでmultiset。ちなみにmultisetでerase(値)をするとその値の要素全てが消えてしまい悲しいことになるが、erase(ite)ならその要素だけが消える、便利。

#include<cstdio>
#include<math.h>
#include<algorithm>
#include<vector>
#include<queue>
#include<string>
#include<set>
#include<cstring>

 
using namespace std;
#define rep(i,n) for(int i=0;i<n;i++)
#define INF 1001001001
#define LLINF 1001001001001001001
#define mp make_pair
#define pb push_back
#define LLIandI pair<long long int , int>
#define ll long long

int main(void){
    int N;scanf("%d",&N);
    long long int a[300004];
    long long int t_b[300004];//前半の終わりがiだった時の前半の得点
    long long int t_a[300004];//前半の終わりがiだった時の後半の得点
    rep(i,N*3){
        scanf("%lld",&a[i]);
    }
    multiset<long long int> min1;
    rep(i,2*N){
        if(i<N){
            if(i>0)t_b[i]=a[i]+t_b[i-1];
            else t_b[i]=a[i];
            min1.insert(a[i]);
        }else{
            t_b[i]=t_b[i-1];
            auto ite=min1.begin();
            if(*ite<a[i]){
                t_b[i]+=a[i]-*ite;
                min1.erase(ite);
                min1.insert(a[i]);
            }
        }
    }
    multiset<long long int> max2;
    int ite2=N-1;
    for(int i=3*N-1;i>N-1;i--){
        if(i>=2*N){
            if(i<3*N-1)t_a[i-1]=a[i]+t_a[i-1+1];
            else t_a[i-1]=a[i];
            max2.insert(a[i]);
        }else{
            t_a[i-1]=t_a[i];
            auto ite=max2.end();
            ite--;
            if(*ite>a[i]){
                t_a[i-1]+=a[i]-*ite;
                max2.erase(ite);
                max2.insert(a[i]);
            }
        }
        
    }
    long long int ans=-LLINF;
    for(int i=N-1;i<2*N;i++){
        ans=max(ans,t_b[i]-t_a[i]);
    }
    printf("%lld\n",ans);
}