【题目描述】:
给定两个数字串A和B,通过将A和B进行二路归并得到一个新的数字串T,请找到字典序最小的T。
【题目解法】:
这里的二路归并并不是真正意义上的归并,因为两个数字串不是单调的
所以我们这样考虑,对于当前的数字串A B,怎样安排输出的顺序使得最后的数字串最小
显然,我们对这这两个数字串比较即可,即如果A>B,那么就输出A的第一位,否则输出B的第一位
然后我们将输出的数字剔除,用新的字符串进行比较,这样的答案就是最优的
直接这样做复杂度是O((N+M)^2)的
但是我们发现每次事实上是比较两个串的后缀,于是直接上后缀数组+RMQ即可
复杂度O((N+M)log(N+M))
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int N=800005;
int n,m,len,height[N],f[21][N],s[N],size;
int rank[N],sa[N],a[N],b[N],x[N],y[N],c[N];
inline long long getint()
{
long long x=0; char c=getchar(); bool flag=false;
while ((c!='-')&&((c<'0')||(c>'9'))) c=getchar();
if (c=='-') flag=true,c=getchar();
while ((c>='0')&&(c<='9')) x=x*10+(long long)(c-'0'),c=getchar();
if (flag) return -x; else return x;
}
void init()
{
n=getint(); for (int i=1; i<=n; i++) s[i]=a[i]=getint();
m=getint(); for (int i=1; i<=m; i++) s[n+1+i]=b[i]=getint();
s[n+1]=1001; len=n+m+1; size=1001;
}
void suffix_array()
{
for (int i=1; i<=len; i++) c[x[i]=s[i]]++;
for (int i=1; i<=size; i++) c[i]+=c[i-1];
for (int i=1; i<=len; i++) sa[c[x[i]]--]=i;
for (int p=1; p<len; p=p*2)
{
int tot=0;
for (int i=len-p+1; i<=len; i++) y[++tot]=i;
for (int i=1; i<=len; i++) if (sa[i]>p) y[++tot]=sa[i]-p;
for (int i=1; i<=size; i++) c[i]=0;
for (int i=1; i<=len; i++) c[x[i]]++;
for (int i=1; i<=size; i++) c[i]+=c[i-1];
for (int i=len; i>=1; i--) sa[c[x[y[i]]]--]=y[i]; size=0;
for (int i=1; i<=len; i++)
{
int u=sa[i-1],v=sa[i];
if (i==0) {y[v]=++size; continue;}
if ((x[u]==x[v])&&(x[u+p]==x[v+p])) y[v]=size; else y[v]=++size;
}
if (size>=len) break;
for (int i=1; i<=len; i++) x[i]=y[i];
}
for (int i=1; i<=len; i++) rank[sa[i]]=i;
}
void Calc()
{
int h=0;
for (int i=1; i<=len; i++)
{
if (rank[i]==1) {height[1]=0; continue;} int j=sa[rank[i]-1];
while ((i+h<=len)&&(j+h<=len)&&(s[i+h]==s[j+h])) h++;
height[rank[i]]=h; if (h>0) h--;
}
for (int i=1; i<=len; i++) f[0][i]=height[i];
for (int i=1; i<=20; i++)
for (int j=1; j<=len; j++)
if (j+(1<<i)-1<=len) f[i][j]=min(f[i-1][j],f[i-1][j+(1<<(i-1))]);
}
inline int ask(int l,int r)
{
if (l>r) swap(l,r); l++;
int k=(int)(log(r-l+1)/log(2));
return min(f[k][l],f[k][r-(1<<k)+1]);
}
void solve()
{
for (int i=1,j=1; (i<=n)||(j<=m); )
{
if (i>n) {printf("%d ",b[j++]); continue;}
if (j>m) {printf("%d ",a[i++]); continue;}
int k=ask(rank[i],rank[n+1+j]);
if (i+k-1>=n) {printf("%d ",b[j++]); continue;}
if (j+k-1>=m) {printf("%d ",a[i++]); continue;}
if (a[i+k]<b[j+k]) printf("%d ",a[i++]); else printf("%d ",b[j++]);
}
printf("\n");
}
int main()
{
init();
suffix_array();
Calc();
solve();
return 0;
}