题意:

题解:
询问l~r的答案,很显然我们直接转化成0~r的答案减去0~(l-1)的答案。
如何求0~a的答案,从高位往低位枚举。
(1)如果当前位置在a中是1,那么这个位置可以取0,也可以取1。如果取0,那后面的随便取都无所谓,也就是说,后面的位置的0,1是可以随便乱放的,也就是说当前这一位为0^(x的这一位),然后之后位随便取,恰好是一个区间,可以求区间的答案。接下来把这一位取1^(x的这一位)然后继续往低位走。
(2)如果当前位置在a中是0,那么这个位置只能取0,就把这一位取0^(x的这一位),接着走。
接下来就是快速查询一段区间的答案,也就是一段区间的f(i)值。方法很多,可以二分直接做,然而我写了颗线段树维护。。。
#include<cstdio>
#include<algorithm>
#include<cstdlib>
using namespace std;
const int mod=998244353,INF=(1<<30)-1;
int n,q,a[100002],cnt=1;
typedef struct{
int ls,rs;
long long sum,f;
}P;
P p[10000002];
void pushdown(int root,int begin,int mid,int end){
if (p[root].f)
{
if (!p[root].ls)p[root].ls=++cnt;
if (!p[root].rs)p[root].rs=++cnt;
p[p[root].ls].sum=(p[p[root].ls].sum+(mid-begin+1)*p[root].f)%mod;
p[p[root].rs].sum=(p[p[root].rs].sum+(end-mid)*p[root].f)%mod;
p[p[root].ls].f=(p[p[root].ls].f+p[root].f)%mod;
p[p[root].rs].f=(p[p[root].rs].f+p[root].f)%mod;
p[root].f=0;
}
}
void gx(int root,int begin,int end,int begin2,int end2,long long z){
if (begin>=begin2 && end<=end2)
{
p[root].sum=(p[root].sum+z*(end-begin+1))%mod;p[root].f=(p[root].f+z)%mod;
return;
}
int mid=(begin+end)/2;pushdown(root,begin,mid,end);
if (!(begin>end2 || mid<begin2))
{
if (!p[root].ls)p[root].ls=++cnt;
gx(p[root].ls,begin,mid,begin2,end2,z);
}
if (!(mid+1>end2 || end<begin2))
{
if (!p[root].rs)p[root].rs=++cnt;
gx(p[root].rs,mid+1,end,begin2,end2,z);
}
p[root].sum=(p[p[root].ls].sum+p[p[root].rs].sum)%mod;
}
long long cx(int root,int begin,int end,int begin2,int end2){
if (begin>end2 || end<begin2 || !root)return 0;
if (begin>=begin2 && end<=end2)return p[root].sum;
int mid=(begin+end)/2;pushdown(root,begin,mid,end);
return (cx(p[root].ls,begin,mid,begin2,end2)+cx(p[root].rs,mid+1,end,begin2,end2))%mod;
}
int js(int x,int z){
if (x<0)return 0;
int rt=1,ans=0,t=0;
for (int i=30;i>=0;i--)
if ((1<<i)&x)
{
bool u;
if ((1<<i)&z)u=1;else u=0;
ans=(ans+cx(1,0,INF,t+(1<<i)*u,t+(1<<i)*u+((1<<i)-1)))%mod;
t+=(1<<i)*(u^1);
}
else
{
bool u;
if ((1<<i)&z)u=1;else u=0;
t+=(1<<i)*u;
}
ans=(ans+cx(1,0,INF,t,t))%mod;
return ans;
}
int main()
{
scanf("%d%d",&n,&q);
for (int i=1;i<=n;i++)scanf("%d",&a[i]);
sort(a+1,a+n+1);
a[n+1]=INF+1;
for (int i=1;i<=n;i++)
if (a[i]!=a[i+1])gx(1,0,INF,a[i],a[i+1]-1,(long long)i*i%mod);
for (int i=1;i<=q;i++)
{
int l,r,x;
scanf("%d%d%d",&l,&r,&x);
printf("%d\n",((js(r,x)-js(l-1,x))%mod+mod)%mod);
}
return 0;
}
来源:https://www.cnblogs.com/1124828077ccj/p/12271694.html