【题目描述】:
给定方程
X1+X2+. +Xn=M
我们对第l..N1个变量进行一些限制:
Xl < = A
X2 < = A2
Xn1 < = An1
我们对第n1 + 1..n1+n2个变量进行一些限制:
Xn1+l > = An1+1
Xn1+2 > = An1+2
Xnl+n2 > = Anl+n2
求:在满足这些限制的前提下,该方程正整数解的个数。
答案可能很大,请输出对p取模后的答案,也即答案除以p的余数。
【题目解法】:
考虑对于n个未知数和为m的正整数解的组数,根据插板法答案就是C(m-1,n-1)
现在考虑有限制的情况
如果限制是>=,那么我们将m减去(A-1),然后这个未知数就变成正整数限制了
考虑限制是<=,那么直接做不好求,我们不妨考虑>的情况,显然我们把m减去A,就变成正整数限制了
然后把不符合的情况减去即可
当然在做的过程中可能会出现组合数对任意数取模,因此需要扩展Lucas
复杂度O(2^N*(logM/logP))
写了一下午TAT
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int N=100005;
int cases,n,m,n1,n2,limit[N],tot,a[N],b[N];
long long maxmod,c[N],r[N],ans;
inline long long getint()
{
long long x=0; char c=getchar(); bool flag=false;
while ((c!='-')&&((c<'0')||(c>'9'))) c=getchar();
if (c=='-') flag=true,c=getchar();
while ((c>='0')&&(c<='9')) x=x*10+(long long)(c-'0'),c=getchar();
if (flag) return -x; else return x;
}
inline long long calc(long long x,int y,long long maxmod)
{
long long nowans=1;
for (;y;y/=2,x=x*x%maxmod) if (y&1) nowans=(nowans*x)%maxmod;
return nowans;
}
namespace Prime
{
int prime[N],cnt; bool v[N];
void init()
{
cnt=0; memset(v,0,sizeof(v));
for (int i=2; i<N; i++)
{
if (!v[i]) prime[++cnt]=i;
for (int j=1; (j<=cnt)&&(i*prime[j]<N); j++)
{
v[i*prime[j]]=true;
if (i%prime[j]==0) break;
}
}
}
void solve()
{
init(); int x=maxmod;
for (int i=1; (i<=cnt)&&(x!=1); i++)
{
if (x%prime[i]!=0) continue; a[++tot]=prime[i]; c[tot]=1;
for (;x%prime[i]==0;x/=prime[i],b[tot]++);
for (int j=1; j<=b[tot]; j++) c[tot]*=prime[i];
}
if (x>1) a[++tot]=x,b[tot]=1,c[tot]=x;
}
}
namespace CRT
{
long long x0,y0;
long long gcd(long long a,long long b)
{
if (b==0) {x0=1; y0=0; return a;}
long long d=gcd(b,a%b);
long long t=x0; x0=y0; y0=t-(a/b)*y0;
return d;
}
long long inv(long long x,long long nowmod)
{
long long d=gcd(x,nowmod);
x0*=d; x0=(x0%nowmod+nowmod)%nowmod;
return x0;
}
long long solve()
{
long long x=c[1],y=r[1];
for (int i=2; i<=tot; i++)
{
long long rest=r[i]-y,d=gcd(x,c[i]);
if (rest%d!=0) return -1; x0=x0*(rest/d);
x0=(x0%(c[i]/d)+c[i]/d)%(c[i]/d);
y+=x*x0; x=x*c[i]/d; y=(y%x+x)%x;
}
return (y%maxmod+maxmod)%maxmod;
}
}
namespace Lucas
{
struct node {long long a,b;};
long long p,c,nowmod;
inline long long fac(int m,int n)
{
long long nowans=1;
for (int i=m; i<=n; i++) if (i%p!=0) nowans=nowans*i%nowmod;
return nowans;
}
node dfs(int n)
{
if (n<=1) return (node){1,0};
node nowans=dfs(n/p); nowans.b=nowans.b+n/p;
long long tmp1=fac(1,nowmod-1),tmp2=fac((n/nowmod)*nowmod+1,n);
nowans.a=nowans.a*calc(tmp1,n/nowmod,nowmod)%nowmod*tmp2%nowmod;
return nowans;
}
inline long long solve(int n,int m)
{
if (n<m) return 0; if (n==m) return 1; if (m==0) return 1;
node A=dfs(n),B=dfs(m),C=dfs(n-m);
return calc(p,A.b-B.b-C.b,nowmod)*A.a%nowmod*CRT::inv(B.a,nowmod)%maxmod*CRT::inv(C.a,nowmod)%nowmod;
}
}
inline long long C(int n,int m)
{
if (n<m) return 0;
if (n<0) return 0;
if (m<0) return 0;
for (int i=1; i<=tot; i++)
{
Lucas::p=a[i]; Lucas::c=b[i]; Lucas::nowmod=c[i];
r[i]=Lucas::solve(n,m);
}
return CRT::solve();
}
void init()
{
n=getint(); n1=getint(); n2=getint(); m=getint();
for (int i=1; i<=n1; i++) limit[i]=getint();
for (int i=1; i<=n2; i++) m-=(getint()-1);
ans=0;
}
void dfs(int x,int sum,int cnt)
{
if (x>n1)
{
if (cnt%2==0) ans+=C(m-sum-1,n-1);
else ans-=C(m-sum-1,n-1);
ans=(ans%maxmod+maxmod)%maxmod;
return;
}
dfs(x+1,sum,cnt);
dfs(x+1,sum+limit[x],cnt+1);
}
int main()
{
cases=getint(); maxmod=getint(); Prime::solve();
while (cases--)
{
init();
dfs(1,0,0);
printf("%lld\n",ans);
}
return 0;
}