歡迎來到Linux教程網
Linux教程網
Linux教程網
Linux教程網
Linux教程網 >> Linux編程 >> Linux編程 >> C++實現矩陣乘法

C++實現矩陣乘法

日期:2017/3/1 9:17:21   编辑:Linux編程

重載*運算符為友元函數。

#include <iostream>
#include <cmath>
using namespace std;

class Matrix{
public:
Matrix(){}
Matrix(int,int);
void setMatrix();
void showMatrix();
void showTransposedMatrix();
friend Matrix operator *(Matrix m1,Matrix m2);
protected:
int m;
int n;
int mn;
double* matrixPtr;
double* transposedMPtr;
void transpose();
};

class SquareMatrix:public Matrix{
public:
SquareMatrix(){}
SquareMatrix(int);
void setSquareMatrix();
void setDet();
void getDet();
private:
double det;
};

Matrix::Matrix(int mt,int nt){
m=mt;
n=nt;
mn=m*n;
matrixPtr=new double[mn];
}

void Matrix::setMatrix(){
cout<<"輸入矩陣的行數和列數:"<<endl;
cin>>m>>n;
mn=m*n;
matrixPtr=new double[mn];
for(int i=0;i<mn;i++)
cin>>matrixPtr[i];
}

void Matrix::transpose(){
transposedMPtr=new double[mn];
for(int i=0;i<n;i++)
for(int j=0;j<m;j++)
transposedMPtr[m*i+j]=matrixPtr[n*j+i];
}

void Matrix::showMatrix(){
for(int i=0;i<m;i++){
for(int j=0;j<n;j++)
cout<<matrixPtr[n*i+j]<<' ';
cout<<endl;
}
}

void Matrix::showTransposedMatrix(){
for(int i=0;i<n;i++){
for(int j=0;j<m;j++)
cout<<transposedMPtr[m*i+j]<<' ';
cout<<endl;
}
}

Matrix operator *(Matrix m1,Matrix m2){
Matrix m3(m1.m,m2.n);
for(int i=0;i<m3.m;i++)
for(int j=0;j<m3.n;j++){
double val=0;
for(int k=0;k<m2.m;k++)
val+=m1.matrixPtr[m1.n*i+k]*m2.matrixPtr[m2.n*k+j];
m3.matrixPtr[m3.n*i+j]=val;
}
return m3;
}

SquareMatrix::SquareMatrix(int m){
Matrix(m,m); //right?
}

void SquareMatrix::setSquareMatrix(){
cout<<"輸入方陣的階數:"<<endl;
cin>>m;
n=m;
mn=m*n;
matrixPtr=new double[mn];
for(int i=0;i<mn;i++)
cin>>matrixPtr[i];
}

void SquareMatrix::setDet(){
double valDet(double*,int);
det=valDet(matrixPtr,m);
}

void SquareMatrix::getDet(){
cout<<det<<endl;
}
double valDet( double *detPtr, int rank)
{
double val=0;
if(rank==1) return detPtr[0];
for(int i=0;i<rank;i++) //計算余子式保存在nextDetPtr[]中
{
double *nextDetPtr=new double[(rank-1)*(rank-1)];
for(int j=0;j<rank-1;j++)
for(int k=0;k<i;k++)
nextDetPtr[j*(rank-1)+k]=detPtr[(j+1)*rank+k];
for(int j=0;j<rank-1;j++)
for(int k=i;k<rank-1;k++)
nextDetPtr[j*(rank-1)+k]=detPtr[(j+1)*rank+k+1];
val+=detPtr[i]*valDet(nextDetPtr,rank-1)*pow(-1.0,i);
}
return val;
}

int main(){
Matrix m1,m2,m3;
m1.setMatrix();
m2.setMatrix();
m3=m1*m2;
m3.showMatrix();
return 0;
}

Copyright © Linux教程網 All Rights Reserved