Naomi's notebook

Naomi's notebook

☆二項係数(ABC127E - Cell Distance)

atcoder.jp

時間内にこれを解き終わらず、さらにDでミスったため激冷えし冷め冷めになってしまった一回。リベンジしていきたいと思います。
まず、式をXとY独立に考えていいというところまでは誰でも思いつくと思います。
ABC途中の考察では色々めんどくさいことをしていたけど、解答を見たらもう少し楽にできそうなので、最初の方だけ見て再構築しました。(解答AC)

解法

Xの差について考える。Nこのマスの中から(M個まで重複ありで)K個選ぶこととなる。
まず、K個選んでしまったと仮定して、その時の距離の合計を考える。
K個のうち2個選んだ時その間の距離は、その他のK-2個とその二個の距離の上でK-2回カウントされる。つまり、その間の距離dはK回足される。
この距離dになるような2つの選び方は(N-d)*M*M通りである
ある2個(距離d)を選んだ時、それがKこの中に含まれているようなK個の選び方は、N*M-2_C_K-2で求められる。(求めるのにはN*Mのオーダー)
よって、Xの差の合計は(N*M-2_C_K-2)*((N-d)*M*M)*i を1 \leqq d \leqq N-1の範囲で足し合わせれば良い。
Yについても同じことをやって、終わり
あと10^9+7は素数だからmodしてから掛け算できるのも基本だけど大事

Combination

と思っていたらもう一つ関門があって、このサイズだとcombinationがすぐオーバーフローするので工夫しなきゃいけない…
計算途中でオーバーフローするどころか、計算結果も3007桁とかあるので、どうにかmodでうまく操作したいところです。
こんな素晴らしい記事を見つけたので貼っておきます!
qiita.com
これを参考にして実装しました
drken1215.hatenablog.com
合同式について、10^9+7は素数なので、和、差、積、べき乗に加えて商もmodが一致する、ということはわかってたけど、商は色々めんどくさいので-1乗だと考えると良いのか…

#include<cstdio>
#include<math.h>
#include<algorithm>
#include<vector>
#include<queue>
#include<string>
#include<set>
#include<cstring>

 
using namespace std;
#define rep(i,n) for(int i=0;i<n;i++)
#define INF 1001001001
#define LLINF 1001001001001001001
#define mp make_pair
#define pb push_back
#define LLIandI pair<long long int , int>
#define ll long long 
ll int N,M,K;
ll int mod=1000000000+7;

#include <iostream>
using namespace std;

const int MAX = 510000;
const int MOD = 1000000007;


// 二項係数をmodしたものを求められるクラス 計算量はnCrについてn
struct Combination{
    ll int MAX_N;
    ll int MOD;
    ll int *fac, *finv, *inv;
    Combination(ll int max_n,ll int mod): MAX_N(max_n+1),MOD(mod){ //make table
        fac = new ll int[MAX_N];finv = new ll int[MAX_N];inv = new ll int[MAX_N];
        fac[0]=fac[1]=1;
        finv[0]=finv[1]=1;
        inv[1]=1;
        for (ll int i = 2; i < MAX_N; i++){
            fac[i] = fac[i - 1] * i % MOD;
            inv[i] = MOD - inv[MOD%i] * (MOD / i) % MOD;
            finv[i] = finv[i - 1] * inv[i] % MOD;
        }
    }
    ll int COM(ll int n,ll int k){
        if (n < k) return 0;
        if (n < 0 || k < 0) return 0;
        return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
    }
    void free_COM(){
        delete[] fac;
        delete[] finv;
        delete[] inv;
    }
    
};


int main(void){
    scanf("%lld %lld %lld",&N,&M,&K);
    ll int ans=0;
    
    //N
    Combination CB(N*M,mod);
    ll int com=CB.COM(N*M-2,K-2);
    CB.free_COM();
    for(ll int i=1;i<N;i++){
        ll int msum=((N-i)*(M*M%mod))%mod;
        ans+=(com*msum%mod)*i%mod;
        ans%=mod;
    }
    //M
    for(ll int i=1;i<M;i++){
        ll int msum=((M-i)*(N*N%mod))%mod;
        ans+=(com*msum%mod)*i%mod;
        ans%=mod;
    }
    printf("%lld\n",ans%mod);
}