读书人

看看这个矩阵模板该如何优化(没分了)

发布时间: 2012-03-16 16:34:56 作者: rapoo

看看这个矩阵模板该怎么优化(没分了)

#include <algorithm>  // std::copy, std::swap
#include <cmath>      // std::fabs
#include <cstddef>    // NULL, std::size_t
#include <utility>    // std::swap (C++11 location)
#include <vector>     // std::vector

/// Layout of the flat source array passed to CMatrix::SetMatrix.
enum ARRAY_STORAGE_MODE
{
    SET_BY_COL = 0,          // array is column-major: element (i, j) at array[j * rows + i]
    Set_BY_ROW = 1,          // array is row-major: element (i, j) at array[i * cols + j]
    SET_BY_ROW = Set_BY_ROW  // consistently cased alias of Set_BY_ROW
};

/// Dense rows x cols matrix of T.
///
/// Storage: `p` is an array of `m` row pointers, each row a separate array of
/// `n` elements, so `M[i][j]` is row i, column j. An empty matrix has
/// m == n == 0 and p == NULL. `m`, `n` and `p` stay public for compatibility
/// with existing callers; do not free or reassign rows directly.
///
/// Dimension mismatches in operator+ / operator* do not throw: the result is
/// an all-zero matrix of the "expected" size, as in the original code.
template <class T>
class CMatrix
{
public:
    /// Empty 0 x 0 matrix. `val` is ignored (kept for source compatibility).
    /// explicit: prevents `a + 5` from silently converting 5 to a matrix.
    explicit CMatrix(int val = 0);
    /// rows x cols matrix, every element value-initialized (0 for arithmetic T).
    /// A non-positive dimension yields an empty matrix.
    CMatrix(int rows, int cols);
    /// Deep copy.
    CMatrix(const CMatrix<T>& a);

    ~CMatrix(void);

    /// Fills the matrix from a flat array of GetRows() * GetCols() elements
    /// laid out according to `mode`.
    /// @return false if `array` is NULL or `mode` is unknown, true otherwise.
    bool SetMatrix(const T* array, ARRAY_STORAGE_MODE mode);

    int GetRows() const { return m; }
    int GetCols() const { return n; }

    /// Deep-copy assignment; safe for self-assignment.
    CMatrix<T>& operator=(const CMatrix<T>& a);
    /// Row access: M[i][j]. No bounds checking.
    T*& operator[](int i);
    const T* operator[](int i) const;

    CMatrix<T> operator+(const CMatrix<T>& a) const;
    CMatrix<T> operator*(double k) const;
    CMatrix<T> operator*(const CMatrix<T>& a) const;
    /// Returns a copy without the rows whose indices appear in iArray.
    CMatrix<T> CutMatrixRows(const std::vector<int>& iArray) const;
    /// Returns the sub-block [startRow, endRow] x [startCol, endCol] (inclusive).
    CMatrix<T> CutMatrixBlock(int startRow, int endRow, int startCol, int endCol) const;
    CMatrix<T> Transpose() const;
    /// Inverse of a square matrix (Gauss-Jordan with full pivoting).
    /// Returns an empty matrix if the matrix is not square or a pivot's
    /// magnitude falls below `tolerance`.
    CMatrix<T> Inverse(double tolerance = 0.0001) const;

    int m;  // number of rows
    int n;  // number of columns
    T** p;  // row pointers: p[i][j] is row i, column j

private:
    /// Allocates `rows` row arrays of `cols` elements; frees everything and
    /// rethrows if any allocation fails.
    static T** AllocateRows(int rows, int cols, bool zeroFill);
    /// Frees `count` rows and the row-pointer array (NULL-safe).
    static void FreeRows(T** rows, int count);
    /// Exchanges contents with `other` (never throws).
    void Swap(CMatrix<T>& other);
};

template <class T>
T** CMatrix<T>::AllocateRows(int rows, int cols, bool zeroFill)
{
    T** rowPtrs = new T*[rows];
    int built = 0;
    try
    {
        for (; built < rows; ++built)
        {
            // `new T[n]()` value-initializes (zero for arithmetic T); skip it
            // when the caller overwrites every element anyway.
            rowPtrs[built] = zeroFill ? new T[cols]() : new T[cols];
        }
    }
    catch (...)
    {
        FreeRows(rowPtrs, built);
        throw;
    }
    return rowPtrs;
}

template <class T>
void CMatrix<T>::FreeRows(T** rows, int count)
{
    if (rows == NULL)
        return;
    for (int i = 0; i < count; ++i)
        delete[] rows[i];
    delete[] rows;
}

template <class T>
void CMatrix<T>::Swap(CMatrix<T>& other)
{
    std::swap(m, other.m);
    std::swap(n, other.n);
    std::swap(p, other.p);
}

template <class T>
CMatrix<T>::CMatrix(int /*val*/)
    : m(0), n(0), p(NULL)
{
}

template <class T>
CMatrix<T>::CMatrix(int rows, int cols)
    : m(0), n(0), p(NULL)
{
    if (rows <= 0 || cols <= 0)
        return;
    p = AllocateRows(rows, cols, true);
    m = rows;
    n = cols;
}

template <class T>
CMatrix<T>::CMatrix(const CMatrix<T>& a)
    : m(0), n(0), p(NULL)
{
    if (a.p == NULL || a.m <= 0 || a.n <= 0)
        return;
    p = AllocateRows(a.m, a.n, false);
    m = a.m;
    n = a.n;
    for (int i = 0; i < m; ++i)
        std::copy(a.p[i], a.p[i] + n, p[i]);
}

template <class T>
CMatrix<T>::~CMatrix(void)
{
    FreeRows(p, m);
}

template <class T>
bool CMatrix<T>::SetMatrix(const T* array, ARRAY_STORAGE_MODE mode)
{
    if (array == NULL)
        return false;

    switch (mode)
    {
    case SET_BY_COL:
        // Column-major source: consecutive elements walk down a column.
        for (int j = 0; j < n; ++j)
            for (int i = 0; i < m; ++i)
                p[i][j] = array[j * m + i];
        return true;
    case Set_BY_ROW:
        // Row-major source: consecutive elements walk along a row.
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j)
                p[i][j] = array[i * n + j];
        return true;
    default:
        return false;
    }
}

template <class T>
T*& CMatrix<T>::operator[](int i)
{
    return p[i];
}

template <class T>
const T* CMatrix<T>::operator[](int i) const
{
    return p[i];
}

template <class T>
CMatrix<T> CMatrix<T>::Transpose() const
{
    CMatrix<T> t(n, m);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            t.p[j][i] = p[i][j];
    return t;
}

template <class T>
CMatrix<T>& CMatrix<T>::operator=(const CMatrix<T>& a)
{
    // Copy-and-swap: the copy is made before our rows are released, so
    // `x = x` is safe and *this is unchanged if the copy throws.
    CMatrix<T> copy(a);
    Swap(copy);
    return *this;
}

template <class T>
CMatrix<T> CMatrix<T>::operator+(const CMatrix<T>& a) const
{
    CMatrix<T> t(m, n);
    if (m != a.m || n != a.n)
        return t;  // dimension mismatch: all-zero m x n result (original behavior)
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            t.p[i][j] = p[i][j] + a.p[i][j];
    return t;
}

template <class T>
CMatrix<T> CMatrix<T>::operator*(double k) const
{
    CMatrix<T> t(m, n);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            t.p[i][j] = static_cast<T>(p[i][j] * k);
    return t;
}

template <class T>
CMatrix<T> CMatrix<T>::operator*(const CMatrix<T>& a) const
{
    const int cols = a.n;
    CMatrix<T> t(m, cols);
    if (n != a.m)
        return t;  // dimension mismatch: all-zero m x a.n result (original behavior)

    // i-k-j order: the inner loop walks contiguous rows of `a` and `t`
    // instead of striding down a column through row pointers. Each t[i][j]
    // still accumulates its terms in increasing k, so results are identical.
    for (int i = 0; i < m; ++i)
    {
        T* dst = t.p[i];
        const T* lhsRow = p[i];
        for (int k = 0; k < n; ++k)
        {
            const T lik = lhsRow[k];
            const T* rhsRow = a.p[k];
            for (int j = 0; j < cols; ++j)
                dst[j] += lik * rhsRow[j];
        }
    }
    return t;
}

template <class T>
CMatrix<T> CMatrix<T>::CutMatrixRows(const std::vector<int>& iArray) const
{
    if (iArray.empty())
        return *this;

    // Mark rows to drop. Indices may be in any order; duplicates and
    // out-of-range indices are ignored.
    std::vector<char> drop(static_cast<std::size_t>(m), char(0));
    int kept = m;
    for (std::vector<int>::size_type idx = 0; idx < iArray.size(); ++idx)
    {
        const int r = iArray[idx];
        if (r >= 0 && r < m && !drop[r])
        {
            drop[r] = 1;
            --kept;
        }
    }

    CMatrix<T> t(kept, n);
    int dst = 0;
    for (int i = 0; i < m; ++i)
    {
        if (drop[i])
            continue;
        std::copy(p[i], p[i] + n, t.p[dst]);
        ++dst;
    }
    return t;
}

template <class T>
CMatrix<T> CMatrix<T>::Inverse(double tolerance) const
{
    if (m != n || m == 0)
        return CMatrix<T>();

    const int size = n;
    CMatrix<T> t(*this);
    std::vector<int> pivotRow(size);
    std::vector<int> pivotCol(size);

    for (int k = 0; k < size; ++k)
    {
        // Step 1: full pivoting - pick the largest-magnitude element of the
        // trailing (size-k) x (size-k) submatrix.
        double maxAbs = 0.0;
        pivotRow[k] = k;
        pivotCol[k] = k;
        for (int i = k; i < size; ++i)
        {
            for (int j = k; j < size; ++j)
            {
                const double mag = std::fabs(static_cast<double>(t.p[i][j]));
                if (mag > maxAbs)
                {
                    maxAbs = mag;
                    pivotRow[k] = i;
                    pivotCol[k] = j;
                }
            }
        }
        if (maxAbs < tolerance)
            return CMatrix<T>();  // (numerically) singular

        // Move the pivot to (k, k). Rows are separate arrays, so a row
        // exchange is a pointer swap.
        if (pivotRow[k] != k)
            std::swap(t.p[k], t.p[pivotRow[k]]);
        if (pivotCol[k] != k)
            for (int i = 0; i < size; ++i)
                std::swap(t.p[i][k], t.p[i][pivotCol[k]]);

        // Step 2: invert the pivot.
        t.p[k][k] = T(1) / t.p[k][k];
        const T pivotInv = t.p[k][k];

        // Step 3: scale the pivot row.
        for (int j = 0; j < size; ++j)
            if (j != k)
                t.p[k][j] *= pivotInv;

        // Step 4: eliminate column k from every other row.
        for (int i = 0; i < size; ++i)
        {
            if (i == k)
                continue;
            const T factor = t.p[i][k];
            for (int j = 0; j < size; ++j)
                if (j != k)
                    t.p[i][j] -= factor * t.p[k][j];
        }

        // Step 5: update the pivot column.
        for (int i = 0; i < size; ++i)
            if (i != k)
                t.p[i][k] = -t.p[i][k] * pivotInv;
    }

    // Undo the pivot permutations in reverse order: a row exchange made during
    // elimination becomes a column exchange of the inverse, and vice versa.
    for (int k = size - 1; k >= 0; --k)
    {
        if (pivotCol[k] != k)
            std::swap(t.p[k], t.p[pivotCol[k]]);
        if (pivotRow[k] != k)
            for (int i = 0; i < size; ++i)
                std::swap(t.p[i][k], t.p[i][pivotRow[k]]);
    }
    return t;
}

template <class T>
CMatrix<T> CMatrix<T>::CutMatrixBlock(int startRow, int endRow, int startCol, int endCol) const
{
    // Bounds are inclusive. An invalid range yields an empty matrix instead
    // of reading outside the matrix.
    if (startRow < 0 || startCol < 0 || startRow > endRow || startCol > endCol ||
        endRow >= m || endCol >= n)
        return CMatrix<T>();

    const int rows = endRow - startRow + 1;
    const int cols = endCol - startCol + 1;
    CMatrix<T> t(rows, cols);
    for (int i = 0; i < rows; ++i)
    {
        const T* src = p[startRow + i] + startCol;
        std::copy(src, src + cols, t.p[i]);
    }
    return t;
}
上面的类是小弟在用 VC 改写 MATLAB 程序时编写的矩阵类（参考了网上的代码）。实际运行发现效率很差，耗时大约是 MATLAB 的 5 倍。我怀疑问题出在这个类上，请问大家该如何优化？谢谢！

[解决办法]
http://community.csdn.net/Expert/topic/5265/5265271.xml?temp=.7587549
去找现成的来用,不要自己写。
[解决办法]
矩阵乘法有较大的优化空间，例如 Strassen 快速矩阵乘法，可以搜索一下相关资料。

读书人网 >C++

热点推荐