00001
00004
00005
00006 #ifndef TBCI_SOLVER_BICG_H
00007 #define TBCI_SOLVER_BICG_H
00008
00009 #include "../basics.h"
00010
00011 #if !defined(NO_GD) && !defined(AUTO_DECL)
00012 # include "bicg_gd.h"
00013 #endif
00014
00015 NAMESPACE_TBCI
00016
// Declarations for the BiCG template on TBCI's dense (Matrix/Vector) and
// banded (BdMatrix/Vector) system types.
// NOTE(review): INST3 is a TBCI macro defined outside this file; presumably it
// expands to an explicit-instantiation / friend declaration of BiCG for the
// given (T, matrix, vector) triple -- confirm against its definition.
00017 INST3(template <typename T, Matrix<T>, Vector<T> > class NN friend \
00018 int BiCG(const Matrix<T>&, Vector<T>&, const Vector<T>&,\
00019 const Preconditioner_Sig<T, Matrix<T> >&, unsigned int&, double&);)
// Same declaration for the banded matrix type BdMatrix<T>.
00020 INST3(template <typename T, BdMatrix<T>, Vector<T> > class NN friend \
00021 int BiCG(const BdMatrix<T>&, Vector<T>&, const Vector<T>&,\
00022 const Preconditioner_Sig<T, BdMatrix<T> >&, unsigned int&, double&);)
00023
00045
00046 template <typename T, typename SysMatrix, typename SysVector>
00047 int BiCG(const SysMatrix& A, SysVector& x, const SysVector& b,
00048 const Preconditioner_Sig<T, SysMatrix>& M, unsigned int& max_iter, double& tol)
00049 {
00050 unsigned int dim = A.rows();
00051 typename SysVector::value_type rho_1, rho_2(0), alpha, beta;
00052 SysVector z(dim), ztilde(dim), p(dim), ptilde(dim), q(dim), qtilde(dim);
00053
00054 double residsqr;
00055 const double tolsqr = sqr(tol);
00056 double normbsqr = b.fabssqr();
00057
00058 if (normbsqr < 1e-32)
00059 normbsqr = 1e32;
00060 else
00061 normbsqr = 1.0 / normbsqr;
00062
00063 SysVector r(b - A * x);
00064 SysVector rtilde(conj(r));
00065
00066 if ((residsqr = r.fabssqr() * normbsqr) <= tolsqr) {
00067 tol = MATH__ sqrt(residsqr);
00068 max_iter = 0;
00069 return 0;
00070 }
00071
00072 for (unsigned int i = 1; i <= max_iter; ++i) {
00073 z = M.solve(r);
00074 ztilde = M.transSolve(rtilde);
00075 rho_1 = z * rtilde;
00076 if (rho_1 == (typename SysVector::value_type)0) {
00077 tol = MATH__ sqrt(r.fabssqr() * normbsqr);
00078 max_iter = i;
00079 return 2;
00080 }
00081 if (UNLIKELY(i == 1)) {
00082 p = z;
00083 ptilde = ztilde;
00084 } else {
00085 beta = rho_1 / rho_2;
00086 p = z + beta * p;
00087 ptilde = ztilde + beta * ptilde;
00088 }
00089
00090 q = A * p;
00091 qtilde = A.transMult(ptilde);
00092 alpha = rho_1 / (q * ptilde);
00093 x += alpha * p;
00094 r -= alpha * q;
00095 rtilde -= alpha * qtilde;
00096
00097 rho_2 = rho_1;
00098 if ((residsqr = r.fabssqr() * normbsqr) < tolsqr) {
00099 tol = MATH__ sqrt(residsqr);
00100 max_iter = i;
00101 return 0;
00102 }
00103 }
00104
00105 tol = MATH__ sqrt(residsqr);
00106 return 1;
00107 }
00108
00109 NAMESPACE_END
00110
00111 #endif