P3390 矩阵快速幂

柔情痞子 提交于 2020-03-10 12:55:08

 

这个题根据题目也就能知道应该怎么做,但是代码怎么实现矩阵乘法,是一个问题,所以就用到了重载运算符。

重载运算符可以定义一些普通的运算,比如   + ,-,×,÷,%,<,>,!=,……有很多,但不能自己创造符号。

在这个题中,需要定义矩阵乘法,在定义之前,还要定义一个结构体:

 1 struct hls{
 2     long long s[110][110];
 3 };
 4 hls t,r;
 5 long long k;
 6 int n;
 7 const long long m=1000000007;
 8 hls operator * (const hls &a,const hls &b)
 9 {
10     hls w;
11     for(int i=1;i<=n;++i)
12     {
13         for(int j=1;j<=n;++j)
14         {
15             w.s[i][j]=0;
16         }
17     }
18     for(int x=1;x<=n;++x)
19     {
20         for(int y=1;y<=n;++y)
21         {
22             for(int z=1;z<=n;++z)
23             {
24                 w.s[x][y]+=a.s[x][z]*b.s[z][y]%m;
25                 w.s[x][y]%=m;
26             }
27         }
28     }
29     return w;
30 }

结构体中包含一个二维数组,用来表示矩阵。其中第8行之后就是定义重载运算符 * 的代码。

重载运算符语法格式:返回类型(结构体),operator ,定义的符号,后面的括号内再写相应的参数。

比如代码中的a,b。在前面要加 取址符 &,因为如果不加,在程序中就会自行复制 a‘和b’,这样就相当于又开了两个二维数组,不仅耗内存,而且浪费时间。

在大括号内(9—30行)在其中定义结构体 w,用来存储运算后的结果,首先将其清零,接下来用三个for循环,来进行矩阵的乘法运算,

w.s[x][y]就是w结构体中数组第x行,y列的位置,所以在矩阵乘法中,w.s[x][y]的结果就是  a.s的第x行依次与b.s的第y列相乘。所以用三个循环就可以定义矩阵乘法。

m的值为1e9+7,题目要求每个元素对1e9+7取模,所以乘法中应该每步都取模,防止数太大。

重载运算符定义好后,就到了快速幂:

 

 1 for(int i=1;i<=n;++i)
 2 {
 3     r.s[i][i]=1;
 4 }
 5 while(k>0) 
 6 {
 7     if(k%2==1) r=r*t;
 8     t=t*t;
 9     k/=2;
10 }

 

r是一个单位矩阵,k是指数,while中因为r,t都是hls类型的,所以会自动调用重载运算符。

最终,r 就是结果,输出 r 就可以。

完整代码:

 1 #include<iostream>
 2 using namespace std;
 3 struct hls{
 4     long long s[110][110];
 5 };
 6 hls t,r;
 7 long long k;
 8 int n;
 9 const long long m=1000000007;
10 hls operator * (const hls &a,const hls &b)
11 {
12     hls w;
13     for(int i=1;i<=n;++i)
14     {
15         for(int j=1;j<=n;++j)
16         {
17             w.s[i][j]=0;
18         }
19     }
20     for(int x=1;x<=n;++x)
21     {
22         for(int y=1;y<=n;++y)
23         {
24             for(int z=1;z<=n;++z)
25             {
26                 w.s[x][y]+=a.s[x][z]*b.s[z][y]%m;
27                 w.s[x][y]%=m;
28             }
29         }
30     }
31     return w;
32 }
33 int main()
34 {
35     cin>>n>>k;
36     for(int i=1;i<=n;++i)
37     {
38         for(int j=1;j<=n;++j)
39         {
40             cin>>t.s[i][j];
41         }
42     }
43     for(int i=1;i<=n;++i)
44     {
45         r.s[i][i]=1;
46     }
47     while(k>0) 
48     {
49        if(k%2==1) r=r*t;
50        t=t*t;
51        k/=2;
52     }
53     for(int i=1;i<=n;++i)
54     {
55         for(int j=1;j<=n;++j)
56         {
57             cout<<r.s[i][j]<<" ";
58         }
59         cout<<endl;
60     }
61     return 0;
62 }

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!