SpikeGPU  1.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
bicgstab2.h
Go to the documentation of this file.
1 
5 #ifndef SPIKE_BICGSTAB_2_H
6 #define SPIKE_BICGSTAB_2_H
7 
8 #include <vector>
9 
10 #include <cusp/blas.h>
11 #include <cusp/print.h>
12 #include <cusp/array1d.h>
13 
14 #include <spike/monitor.h>
15 #include <spike/precond.h>
16 
17 
18 namespace spike {
19 
20 typedef typename cusp::array1d<int, cusp::host_memory> IntVectorH;
21 
22 
/**
 * Unary predicate that tests a value for equality with a stored reference
 * value. It is used as the stencil predicate of thrust::scatter_if in
 * precondSolveWrapper.
 *
 * @tparam T  value type; must be equality-comparable and constructible from 0.
 */
template <typename T>
struct IsEqualTo
{
	T m_val;   ///< reference value every argument is compared against

	/// Store the reference value (defaults to 0).
	IsEqualTo(T val = 0) : m_val(val) {}

	/// Return true if @a val equals the stored reference value.
	/// Declared const: the call does not modify the functor, so it must also
	/// be usable through const objects and const references.
	__host__ __device__
	bool operator() (const T& val) const
	{
		return m_val == val;
	}
};
36 
37 
38 template <typename SolverVector, typename PrecVector, typename IntVector>
39 void precondSolveWrapper(SolverVector& rhs,
40  SolverVector& sol,
41  std::vector<Precond<PrecVector>*>& precond_pointers,
42  IntVector& compIndices,
43  IntVector& comp_perms,
44  std::vector<IntVector>& comp_reorderings)
45 {
46  int numComponents = comp_reorderings.size();
47 
48  for (int i = 0; i < numComponents; i++) {
49  int loc_n = comp_reorderings[i].size();
50 
51  PrecVector buffer_rhs(loc_n);
52  PrecVector buffer_sol(loc_n);
53 
54  thrust::scatter_if(rhs.begin(), rhs.end(), comp_perms.begin(), compIndices.begin(), buffer_rhs.begin(), IsEqualTo<int>(i));
55  precond_pointers[i]->solve(buffer_rhs, buffer_sol);
56  thrust::scatter(buffer_sol.begin(), buffer_sol.end(), comp_reorderings[i].begin(), sol.begin());
57  }
58 }
59 
60 
62 
68 template <typename SpmvOperator, typename SolverVector, typename PrecVector, int L>
69 void bicgstabl(SpmvOperator& spmv,
70  const SolverVector& b,
71  SolverVector& x,
72  Monitor<SolverVector>& monitor,
73  std::vector<Precond<PrecVector>*>& precond_pointers,
74  IntVectorH& compIndices,
75  IntVectorH& comp_perms,
76  std::vector<IntVectorH>& comp_reorderings)
77 {
78  typedef typename SolverVector::value_type SolverValueType;
79  typedef typename SolverVector::memory_space MemorySpace;
80 
81  typedef typename cusp::array1d<int, MemorySpace> IntVector;
82 
83  // Allocate workspace
84  int n = b.size();
85 
86  SolverValueType rou0 = SolverValueType(1);
87  SolverValueType alpha = SolverValueType(0);
88  SolverValueType omega = SolverValueType(1);
89  SolverValueType rou1;
90 
91  SolverVector r0(n);
92  SolverVector r(n);
93  SolverVector u(n,0);
94  SolverVector xx(n);
95  SolverVector Pv(n);
96 
97  IntVector loc_compIndices = compIndices;
98  IntVector loc_comp_perms = comp_perms;
99  std::vector<IntVector> loc_comp_reorderings;
100 
101  int numComponents = comp_reorderings.size();
102 
103  for (int i = 0; i < numComponents; i++)
104  loc_comp_reorderings.push_back(comp_reorderings[i]);
105 
106  std::vector<SolverVector> rr(L+1);
107  std::vector<SolverVector> uu(L+1);
108 
109  for(int k = 0; k <= L; k++) {
110  rr[k].resize(n, 0);
111  uu[k].resize(n, 0);
112  }
113 
114  SolverValueType tao[L+1][L+1];
115  SolverValueType gamma[L+2];
116  SolverValueType gamma_prime[L+2];
117  SolverValueType gamma_primeprime[L+2];
118  SolverValueType sigma[L+2];
119 
120  // r0 <- b - A * x
121  spmv(x, r0);
122  cusp::blas::axpby(b, r0, r0, SolverValueType(1), SolverValueType(-1));
123 
124  // r <- r0
125  cusp::blas::copy(r0, r);
126 
127  // uu(0) <- u
128  // rr(0) <- r
129  // xx <- x
130  thrust::copy(thrust::make_zip_iterator(thrust::make_tuple(u.begin(), x.begin(), r.begin())),
131  thrust::make_zip_iterator(thrust::make_tuple(u.end(), x.end(), r.end())),
132  thrust::make_zip_iterator(thrust::make_tuple(uu[0].begin(), xx.begin(), rr[0].begin())));
133 
134  while(!monitor.done(r)) {
135 
136  rou0 = -omega * rou0;
137 
138  monitor.increment(0.25f);
139 
140  for(int j = 0; j < L; j++) {
141  rou1 = cusp::blas::dotc(rr[j], r0);
142 
143  // return with failure
144  if(rou0 == 0)
145  return;
146 
147  SolverValueType beta = alpha * rou1 / rou0;
148  rou0 = rou1;
149 
150  for(int i = 0; i <= j; i++) {
151  // uu(i) = rr(i) - beta * uu(i)
152  cusp::blas::axpby(rr[i], uu[i], uu[i], SolverValueType(1), -beta);
153  }
154 
155  // uu(j+1) <- A * P^(-1) * uu(j);
156  // precond.solve(uu[j], Pv);
157  precondSolveWrapper(uu[j], Pv, precond_pointers, loc_compIndices, loc_comp_perms, loc_comp_reorderings);
158  spmv(Pv, uu[j+1]);
159 
160  // gamma <- uu(j+1) . r0;
161  SolverValueType gamma = cusp::blas::dotc(uu[j+1], r0);
162  if(gamma == 0)
163  return;
164 
165  alpha = rou0 / gamma;
166 
167  for(int i = 0; i <= j; i++) {
168  // rr(i) <- rr(i) - alpha * uu(i+1)
169  cusp::blas::axpy(uu[i+1], rr[i], SolverValueType(-alpha));
170  }
171 
172  // rr(j+1) = A * P^(-1) * rr(j)
173  //precond.solve(rr[j], Pv);
174  precondSolveWrapper(rr[j], Pv, precond_pointers, loc_compIndices, loc_comp_perms, loc_comp_reorderings);
175  spmv(Pv, rr[j+1]);
176 
177  // xx <- xx + alpha * uu(0)
178  cusp::blas::axpy(uu[0], xx, alpha);
179 
180  if(monitor.done(rr[0])) {
181  //precond.solve(xx, x);
182  precondSolveWrapper(xx, x, precond_pointers, loc_compIndices, loc_comp_perms, loc_comp_reorderings);
183  return;
184  }
185  }
186 
187 
188  for(int j = 1; j <= L; j++) {
189  for(int i = 1; i < j; i++) {
190  tao[i][j] = cusp::blas::dotc(rr[j], rr[i]) / sigma[i];
191  cusp::blas::axpy(rr[i], rr[j], -tao[i][j]);
192  }
193  sigma[j] = cusp::blas::dotc(rr[j], rr[j]);
194  if(sigma[j] == 0)
195  return;
196  gamma_prime[j] = cusp::blas::dotc(rr[j], rr[0]) / sigma[j];
197  }
198 
199  gamma[L] = gamma_prime[L];
200  omega = gamma[L];
201 
202  for(int j = L-1; j > 0; j--) {
203  gamma[j] = gamma_prime[j];
204  for(int i = j+1; i <= L; i++)
205  gamma[j] -= tao[j][i] * gamma[i];
206  }
207 
208  for(int j = 1; j < L; j++) {
209  gamma_primeprime[j] = gamma[j+1];
210  for(int i = j+1; i < L; i++)
211  gamma_primeprime[j] += tao[j][i] * gamma[i+1];
212  }
213 
214  // xx <- xx + gamma * rr(0)
215  // rr(0) <- rr(0) - gamma'(L) * rr(L)
216  // uu(0) <- uu(0) - gamma(L) * uu(L)
217  cusp::blas::axpy(rr[0], xx, gamma[1]);
218  cusp::blas::axpy(rr[L], rr[0], -gamma_prime[L]);
219  cusp::blas::axpy(uu[L], uu[0], -gamma[L]);
220 
221  monitor.increment(0.25f);
222 
223  if (monitor.done(rr[0])) {
224  // precond.solve(xx, x);
225  precondSolveWrapper(xx, x, precond_pointers, loc_compIndices, loc_comp_perms, loc_comp_reorderings);
226  return;
227  }
228 
229  monitor.increment(0.25f);
230 
231  // uu(0) <- uu(0) - sum_j { gamma(j) * uu(j) }
232  // xx <- xx + sum_j { gamma''(j) * rr(j) }
233  // rr(0) <- rr(0) - sum_j { gamma'(j) * rr(j) }
234  for(int j = 1; j < L; j++) {
235  cusp::blas::axpy(uu[j], uu[0], -gamma[j]);
236  cusp::blas::axpy(rr[j], xx, gamma_primeprime[j]);
237  cusp::blas::axpy(rr[j], rr[0], -gamma_prime[j]);
238 
239  if (monitor.done(rr[0])) {
240  // precond.solve(xx, x);
241  precondSolveWrapper(xx, x, precond_pointers, loc_compIndices, loc_comp_perms, loc_comp_reorderings);
242  return;
243  }
244  }
245 
246  // u <- uu(0)
247  // x <- xx
248  // r <- rr(0)
249  thrust::copy(thrust::make_zip_iterator(thrust::make_tuple(uu[0].begin(), xx.begin(), rr[0].begin())),
250  thrust::make_zip_iterator(thrust::make_tuple(uu[0].end(), xx.end(), rr[0].end())),
251  thrust::make_zip_iterator(thrust::make_tuple(u.begin(), x.begin(), r.begin())));
252 
253  monitor.increment(0.25f);
254  }
255 }
256 
257 
259 template <typename SpmvOperator, typename SolverVector, typename PrecVector>
260 void bicgstab2(SpmvOperator& spmv,
261  const SolverVector& b,
262  SolverVector& x,
263  Monitor<SolverVector>& monitor,
264  std::vector<Precond<PrecVector>*>& precond_pointers,
265  IntVectorH& compIndices,
266  IntVectorH& comp_perms,
267  std::vector<IntVectorH>& comp_reorderings)
268 {
269  bicgstabl<SpmvOperator, SolverVector, PrecVector, 2>(spmv, b, x, monitor, precond_pointers, compIndices, comp_perms, comp_reorderings);
270 }
271 
272 
274 template <typename SpmvOperator, typename SolverVector, typename PrecVector>
275 void bicgstab4(SpmvOperator& spmv,
276  const SolverVector& b,
277  SolverVector& x,
278  Monitor<SolverVector>& monitor,
279  std::vector<Precond<PrecVector>*>& precond_pointers,
280  IntVectorH& compIndices,
281  IntVectorH& comp_perms,
282  std::vector<IntVectorH>& comp_reorderings)
283 {
284  bicgstabl<SpmvOperator, SolverVector, PrecVector, 4>(spmv, b, x, monitor, precond_pointers, compIndices, comp_perms, comp_reorderings);
285 }
286 
287 
288 
289 } // namespace spike
290 
291 
292 
293 #endif
294 
void bicgstabl(SpmvOperator &spmv, const SolverVector &b, SolverVector &x, Monitor< SolverVector > &monitor, std::vector< Precond< PrecVector > * > &precond_pointers, IntVectorH &compIndices, IntVectorH &comp_perms, std::vector< IntVectorH > &comp_reorderings)
Preconditioned BiCGStab(L) Krylov method.
Definition: bicgstab2.h:69
void increment(float incr)
Definition: monitor.h:38
void precondSolveWrapper(SolverVector &rhs, SolverVector &sol, std::vector< Precond< PrecVector > * > &precond_pointers, IntVector &compIndices, IntVector &comp_perms, std::vector< IntVector > &comp_reorderings)
Definition: bicgstab2.h:39
IsEqualTo(T val=0)
Definition: bicgstab2.h:28
Definition: bicgstab2.h:24
bool done(const SolverVector &r)
Definition: monitor.h:89
__host__ __device__ bool operator()(const T &val)
Definition: bicgstab2.h:31
T m_val
Definition: bicgstab2.h:26
Spike preconditioner.
Definition: precond.h:37
void bicgstab4(SpmvOperator &spmv, const SolverVector &b, SolverVector &x, Monitor< SolverVector > &monitor, std::vector< Precond< PrecVector > * > &precond_pointers, IntVectorH &compIndices, IntVectorH &comp_perms, std::vector< IntVectorH > &comp_reorderings)
Specialization of the generic spike::bicgstabl function for L=4.
Definition: bicgstab2.h:275
cusp::array1d< int, cusp::host_memory > IntVectorH
Definition: bicgstab2.h:20
Definition: monitor.h:19
Definition of the Spike preconditioner class.
void bicgstab2(SpmvOperator &spmv, const SolverVector &b, SolverVector &x, Monitor< SolverVector > &monitor, std::vector< Precond< PrecVector > * > &precond_pointers, IntVectorH &compIndices, IntVectorH &comp_perms, std::vector< IntVectorH > &comp_reorderings)
Specialization of the generic spike::bicgstabl function for L=2.
Definition: bicgstab2.h:260