00001
00004
00005
00006 #ifndef TBCI_SOLVER_CGS_H
00007 #define TBCI_SOLVER_CGS_H
00008
00009 #include "../basics.h"
00010
00011 #ifdef SOLVER_LOG
00012 # include <fstream>
00013 # include <iomanip>
00014 #endif
00015
00016 NAMESPACE_TBCI
00017
00018
00019
00020
00021
00022
00023
00024
// Declarations of the CGS solver for each supported system-matrix type:
// Matrix<T>, BdMatrix<T> and Symm_BdMatrix<T>, each paired with Vector<T>.
// NOTE(review): INST3 is a project macro not defined in this file; judging by
// the "class NN friend" tokens it presumably emits these signatures as friend
// declarations (and/or instantiations) -- confirm against its definition.
// Each parameter list must stay in sync with the CGS template defined below;
// the default value of the trailing `off` argument is given only there.
INST3(template <typename T, Matrix<T>, Vector<T> > class NN friend \
int CGS (const Matrix<T>&, Vector<T>&, const Vector<T>&,\
const Preconditioner_Sig<T, Matrix<T> >&, unsigned int&, double&, const unsigned);)
// Banded matrix variant.
INST3(template <typename T, BdMatrix<T>, Vector<T> > class NN friend \
int CGS (const BdMatrix<T>&, Vector<T>&, const Vector<T>&,\
const Preconditioner_Sig<T, BdMatrix<T> >&, unsigned int&, double&, const unsigned);)
// Symmetric banded matrix variant.
INST3(template <typename T, Symm_BdMatrix<T>, Vector<T> > class NN friend \
int CGS (const Symm_BdMatrix<T>&, Vector<T>&, const Vector<T>&,\
const Preconditioner_Sig<T, Symm_BdMatrix<T> >&, unsigned int&, double&, const unsigned);)
00067
00068
00069 template < typename T, typename SysMatrix, typename SysVector >
00070 int CGS(const SysMatrix &A, SysVector &x, const SysVector &b,
00071 const Preconditioner_Sig<T, SysMatrix> &M, unsigned int &max_iter,
00072 double &tol, const unsigned off = 0)
00073 {
00074 const unsigned int dim = A.rows();
00075 SysVector p(dim), phat(dim), q(dim), qhat(dim), vhat(dim), u(dim), uhat(dim);
00076 typename SysVector::value_type rho_1, rho_2((typename SysVector::value_type)1), alpha, beta;
00077
00078 double residsqr;
00079 double tolsqr = TBCI__ sqr (tol);
00080 double normbsqr = b.fabssqr();
00081 if (normbsqr < 1e-32)
00082 normbsqr = 1e32;
00083 else
00084 normbsqr = 1.0 / normbsqr;
00085
00086 SysVector r(b - A*x);
00087 SysVector rtilde(r);
00088
00089
00090 if ((residsqr = r.fabssqr() * normbsqr) <= tolsqr)
00091 {
00092 tol = MATH__ sqrt(residsqr);
00093 max_iter = 0;
00094 return 0;
00095 }
00096 #ifdef SOLVER_LOG
00097 STD__ ofstream cnvg;
00098 if (off) cnvg.open ("cgs_cnvg.gnu", STD__ ios::app);
00099 else cnvg.open ("cgs_cnvg.gnu");
00100 #endif
00101
00102 for (unsigned int i = 1; i <= max_iter; i++)
00103 {
00104 #ifdef SOLVER_LOG
00105 cnvg << i+off << "\t" << MATH__ sqrt(residsqr) << "\n";
00106 #endif
00107 rho_1 = dot(rtilde, r);
00108 if (rho_1 == (typename SysVector::value_type)0)
00109 {
00110 tol = MATH__ sqrt (residsqr);
00111 return 2;
00112 }
00113 if (i == 1)
00114 {
00115 u = r;
00116 p = u;
00117 } else
00118 {
00119 beta = rho_1 / rho_2;
00120 u = r + beta * q;
00121 p = u + beta * (q + beta * p);
00122 }
00123 phat = M.solve(p);
00124 vhat = A*phat;
00125 alpha = rho_1 / dot(rtilde, vhat);
00126 q = u - alpha * vhat;
00127 uhat = M.solve(u + q);
00128 x += alpha * uhat;
00129 qhat = A * uhat;
00130 r -= alpha * qhat;
00131 rho_2 = rho_1;
00132 if ((residsqr = r.fabssqr() * normbsqr) < tolsqr)
00133 {
00134 tol = MATH__ sqrt(residsqr);
00135 max_iter = i;
00136 #ifdef SOLVER_LOG
00137 cnvg << i+off << "\t" << MATH__ sqrt(residsqr) << STD__ endl; cnvg.close ();
00138 #endif
00139 return 0;
00140 }
00141 }
00142
00143 tol = MATH__ sqrt(residsqr);
00144 return 1;
00145 }
00146
00147 NAMESPACE_END
00148
00149 #endif