「题解」
先分组再给组标号,即要求组的大小从小到大,分完以后将方案乘上m!
对于20%的数据,f[i][j][k]表示前i个人分j组,最后一组有k个人,n^4 递推
对于50%的数据
我们发现上一个算法的瓶颈在于,要考虑前一组分了多少个人,这样很麻烦
我们可以考虑只用f[i][j]表示前i个人分j组,之后的m-j组先分配和最后一组同样的人数,那么在考虑后面分组的时候,就不用考虑递增的问题了
这样直接递推是nm^2
其实很容易发现,m*(m+1)/2显然无解,那么m^2应当与n同级,直接特判完dp即可
对于100%的数据
上一个递推算法可以用一些技巧优化到nm,不过比较麻烦
我们可以考虑,首先给第i个小队分配i个人,接下来为了保证人数递增,每次给第i至n个小队增加1个人。
那么答案相当于将n-m*(m+1)/2做1-m的自然数拆分,这个递推比较简单。
复杂度n*m
20%
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
|
#include<map> #include<set> #include<cmath> #include<stack> #include<queue> #include<cstdio> #include<vector> #include<cstring> #include<cstdlib> #include<iostream> #include<algorithm> using namespace std; #define mod 998244353 #define pi acos(-1) #define inf 0x7fffffff #define ll long long using namespace std; ll read() { ll x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int n,m,ans; int f[25][25][25]; //i队,共j人,最后一队k人 int main() { n=read();m=read(); f[0][0][0]=1; for(int i=0;i<=m;i++) for(int j=0;j<=n;j++) for(int last=0;last<=j;last++) for(int k=last+1;j+k<=n;k++) f[i+1][j+k][k]+=f[i][j][last]; for(int i=0;i<=n;i++) ans+=f[m][n][i]; for(int i=1;i<=m;i++) ans=(ans*i)%mod; printf("%d\n",ans); return 0; } |
50%
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
|
#include<map> #include<set> #include<cmath> #include<stack> #include<queue> #include<cstdio> #include<vector> #include<cstring> #include<cstdlib> #include<iostream> #include<algorithm> using namespace std; #define mod 998244353 #define pi acos(-1) #define inf 0x7fffffff #define ll long long using namespace std; ll read() { ll x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int n,m; int f[3005][105]; int main() { n=read();m=read(); if(n<(ll)m*(m+1)/2) puts("0"); else { f[0][0]=1; for(int i=0;i<=n;i++) for(int j=0;j<m;j++) for(int k=1;i+k*(m-j)<=n;k++) { f[i+k*(m-j)][j+1]=(f[i+k*(m-j)][j+1]+f[i][j])%mod; } for(int i=1;i<=m;i++) f[n][m]=((ll)f[n][m]*i)%mod; printf("%d\n",f[n][m]); } return 0; } |
100%
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
|
#include<map> #include<set> #include<cmath> #include<stack> #include<queue> #include<cstdio> #include<vector> #include<cstring> #include<cstdlib> #include<iostream> #include<algorithm> using namespace std; #define mod 998244353 #define pi acos(-1) #define inf 0x7fffffff #define ll long long using namespace std; ll read() { ll x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int n,m; int f[100005]; int main() { n=read();m=read(); if(n<(ll)m*(m+1)/2) puts("0"); else { n-=m*(m+1)/2; f[0]=1; for(int i=1;i<=m;i++) for(int j=i;j<=n;j++) f[j]=(f[j]+f[j-i])%mod; for(int i=1;i<=m;i++) f[n]=(ll)f[n]*i%mod; printf("%d\n",f[n]); } return 0; } |