1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35        SUBROUTINE DGERFS( TRANS, N, NRHS, A, LDA, AF, LDAF, IPIV, B, LDB,
 36       $                   X, LDX, FERR, BERR, WORK, IWORK, INFO )
 37  *
 38  *  -- LAPACK routine (version 3.2) --
 39  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 40  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 41  *     November 2006
 42  */
 43 
 44 #ifndef FLENS_LAPACK_GESV_RFS_TCC
 45 #define FLENS_LAPACK_GESV_RFS_TCC 1
 46 
#include <algorithm>

#include <flens/blas/blas.h>
#include <flens/lapack/lapack.h>
 49 
 50 namespace flens { namespace lapack {
 51 
 52 //== generic lapack implementation =============================================
 53 template <typename MA, typename MAF, typename VPIV, typename MB, typename MX,
 54           typename VFERR, typename VBERR, typename VWORK, typename VIWORK>
 55 void
 56 rfs_generic(Transpose               trans,
 57             const GeMatrix<MA>      &A,
 58             const GeMatrix<MAF>     &AF,
 59             const DenseVector<VPIV> &piv,
 60             const GeMatrix<MB>      &B,
 61             GeMatrix<MX>            &X,
 62             DenseVector<VFERR>      &fErr,
 63             DenseVector<VBERR>      &bErr,
 64             DenseVector<VWORK>      &work,
 65             DenseVector<VIWORK>     &iwork)
 66 {
 67     using std::abs;
 68 
 69     typedef typename GeMatrix<MA>::ElementType  ElementType;
 70     typedef typename GeMatrix<MA>::IndexType    IndexType;
 71 
 72     const IndexType   itMax = 5;
 73     const ElementType Zero(0), One(1), Two(2), Three(3);
 74 
 75     const Underscore<IndexType> _;
 76 
 77     const IndexType n    = B.numRows();
 78     const IndexType nRhs = B.numCols();
 79 //
 80 //  Local Arrays
 81 //
 82     IndexType iSaveData[3] = {000};
 83     DenseVectorView<IndexType>
 84         iSave = typename DenseVectorView<IndexType>::Engine(3, iSaveData, 1);
 85 
 86 //
 87 //  Quick return if possible
 88 //
 89     if (n==0 || nRhs==0) {
 90         fErr = Zero;
 91         bErr = Zero;
 92         return;
 93     }
 94 
 95     const Transpose transT = (trans==NoTrans) ? Trans : NoTrans;
 96 //
 97 //  NZ = maximum number of nonzero elements in each row of A, plus 1
 98 //
 99     const IndexType nz = n + 1;
100     const ElementType eps = lamch<ElementType>(Eps);
101     const ElementType safeMin = lamch<ElementType>(SafeMin);
102     const ElementType safe1 = nz * safeMin;
103     const ElementType safe2 = safe1 / eps;
104 
105     auto work1 = work(_(1,n));
106     auto work2 = work(_(n+1,2*n));
107     auto work3 = work(_(2*n+1,3*n));
108 //
109 //  Do for each right hand side
110 //
111 
112     for (IndexType j=1; j<=nRhs; ++j) {
113         IndexType   count = 1;
114         ElementType lastRes = Three;
115 
116         RETRY:
117 //
118 //      Loop until stopping criterion is satisfied.
119 //
120 //      Compute residual R = B - op(A) * X,
121 //      where op(A) = A, A**T, or A**H, depending on TRANS.
122 //
123         work2 = B(_,j);
124         blas::mv(trans, -One, A, X(_,j), One, work2);
125 //
126 //      Compute componentwise relative backward error from formula
127 //
128 //      max(i) ( abs(R(i)) / ( abs(op(A))*abs(X) + abs(B) )(i) )
129 //
130 //      where abs(Z) is the componentwise absolute value of the matrix
131 //      or vector Z.  If the i-th component of the denominator is less
132 //      than SAFE2, then SAFE1 is added to the i-th components of the
133 //      numerator and denominator before dividing.
134 //
135         for (IndexType i=1; i<=n; ++i) {
136             work(i) = abs(B(i,j));
137         }
138 //
139 //      Compute abs(op(A))*abs(X) + abs(B).
140 //
141         if (trans==NoTrans) {
142             for (IndexType k=1; k<=n; ++k) {
143                 const ElementType xk = abs(X(k,j));
144                 for (IndexType i=1; i<=n; ++i) {
145                     work(i) += abs(A(i,k)) * xk;
146                 }
147             }
148         } else {
149             for (IndexType k=1; k<=n; ++k) {
150                 ElementType s = Zero;
151                 for (IndexType i=1; i<=n; ++i) {
152                     s += abs(A(i,k)) * abs(X(i,j));
153                 }
154                 work1(k) += s;
155             }
156         }
157 
158         ElementType s = Zero;
159         for (IndexType i=1; i<=n; ++i) {
160             if (work1(i)>safe2) {
161                 s = max(s, abs(work2(i))/work1(i));
162             } else {
163                 s = max(s, (abs(work2(i))+safe1)/(work1(i)+safe1));
164             }
165         }
166         bErr(j) = s;
167 //
168 //      Test stopping criterion. Continue iterating if
169 //         1) The residual BERR(J) is larger than machine epsilon, and
170 //         2) BERR(J) decreased by at least a factor of 2 during the
171 //            last iteration, and
172 //         3) At most ITMAX iterations tried.
173 //
174 
175         if (bErr(j)>eps && Two*bErr(j)<=lastRes && count<=itMax) {
176 //
177 //          Update solution and try again.
178 //
179             trs(trans, AF, piv, work2);
180             X(_,j) += work2;
181             lastRes = bErr(j);
182             ++count;
183             goto RETRY;
184         }
185 //
186 //      Bound error from formula
187 //
188 //      norm(X - XTRUE) / norm(X) .le. FERR =
189 //      norm( abs(inv(op(A)))*
190 //         ( abs(R) + NZ*EPS*( abs(op(A))*abs(X)+abs(B) ))) / norm(X)
191 //
192 //      where
193 //        norm(Z) is the magnitude of the largest component of Z
194 //        inv(op(A)) is the inverse of op(A)
195 //        abs(Z) is the componentwise absolute value of the matrix or
196 //           vector Z
197 //        NZ is the maximum number of nonzeros in any row of A, plus 1
198 //        EPS is machine epsilon
199 //
200 //      The i-th component of abs(R)+NZ*EPS*(abs(op(A))*abs(X)+abs(B))
201 //      is incremented by SAFE1 if the i-th component of
202 //      abs(op(A))*abs(X) + abs(B) is less than SAFE2.
203 //
204 //      Use DLACN2 to estimate the infinity-norm of the matrix
205 //         inv(op(A)) * diag(W),
206 //      where W = abs(R) + NZ*EPS*( abs(op(A))*abs(X)+abs(B) )))
207 //
208         for (IndexType i=1; i<=n; ++i) {
209             if (work(i)>safe2) {
210                 work(i) = abs(work2(i)) + nz*eps*work1(i);
211             } else {
212                 work(i) = abs(work2(i)) + nz*eps*work1(i) + safe1;
213             }
214         }
215 
216         IndexType kase = 0;
217         while (true) {
218             lacn2(work3, work2, iwork, fErr(j), kase, iSave);
219             if (kase==0) {
220                 break;
221             }
222             if (kase==1) {
223 //
224 //              Multiply by diag(W)*inv(op(A)**T).
225 //
226                 trs(transT, AF, piv, work2);
227                 for (IndexType i=1; i<=n; ++i) {
228                     work2(i) *= work1(i);
229                 }
230             } else {
231 //
232 //              Multiply by inv(op(A))*diag(W).
233 //
234                 for (IndexType i=1; i<=n; ++i) {
235                     work2(i) *= work1(i);
236                 }
237                 trs(trans, AF, piv, work2);
238             }
239         }
240 //
241 //      Normalize error.
242 //
243         lastRes = Zero;
244         for (IndexType i=1; i<=n; ++i) {
245             lastRes = max(lastRes, abs(X(i,j)));
246         }
247         if (lastRes!=Zero) {
248             fErr(j) /= lastRes;
249         }
250 
251     }
252 }
253 
254 //== interface for native lapack ===============================================
255 
256 #ifdef CHECK_CXXLAPACK
257 
258 template <typename MA, typename MAF, typename VPIV, typename MB, typename MX,
259           typename VFERR, typename VBERR, typename VWORK, typename VIWORK>
260 void
261 rfs_native(Transpose               trans,
262            const GeMatrix<MA>      &A,
263            const GeMatrix<MAF>     &AF,
264            const DenseVector<VPIV> &piv,
265            const GeMatrix<MB>      &B,
266            GeMatrix<MX>            &X,
267            DenseVector<VFERR>      &fErr,
268            DenseVector<VBERR>      &bErr,
269            DenseVector<VWORK>      &work,
270            DenseVector<VIWORK>     &iwork)
271 {
272     typedef typename GeMatrix<MA>::ElementType T;
273 
274     const char       TRANS   = getF77LapackChar(trans);
275     const INTEGER    N       = B.numRows();
276     const INTEGER    NRHS    = B.numCols();
277     const INTEGER    LDA     = A.leadingDimension();
278     const INTEGER    LDAF    = AF.leadingDimension();
279     const INTEGER    LDB     = B.leadingDimension();
280     const INTEGER    LDX     = X.leadingDimension();
281     INTEGER          INFO;
282 
283     if (IsSame<T,double>::value) {
284         LAPACK_IMPL(dgerfs)(&TRANS,
285                             &N,
286                             &NRHS,
287                             A.data(),
288                             &LDA,
289                             AF.data(),
290                             &LDAF,
291                             piv.data(),
292                             B.data(),
293                             &LDB,
294                             X.data(),
295                             &LDX,
296                             fErr.data(),
297                             bErr.data(),
298                             work.data(),
299                             iwork.data(),
300                             &INFO);
301     } else {
302         ASSERT(0);
303     }
304 }
305 
306 #endif // CHECK_CXXLAPACK
307 
308 //== public interface ==========================================================
309 template <typename MA, typename MAF, typename VPIV, typename MB, typename MX,
310           typename VFERR, typename VBERR, typename VWORK, typename VIWORK>
311 void
312 rfs(Transpose               trans,
313     const GeMatrix<MA>      &A,
314     const GeMatrix<MAF>     &AF,
315     const DenseVector<VPIV> &piv,
316     const GeMatrix<MB>      &B,
317     GeMatrix<MX>            &X,
318     DenseVector<VFERR>      &fErr,
319     DenseVector<VBERR>      &bErr,
320     DenseVector<VWORK>      &work,
321     DenseVector<VIWORK>     &iwork)
322 {
323     typedef typename GeMatrix<MA>::IndexType  IndexType;
324 //
325 //  Test the input parameters
326 //
327 #   ifndef NDEBUG
328     ASSERT(A.firstRow()==1);
329     ASSERT(A.firstCol()==1);
330     ASSERT(A.numRows()==A.numCols());
331 
332     const IndexType n = A.numRows();
333 
334     ASSERT(AF.firstRow()==1);
335     ASSERT(AF.firstCol()==1);
336     ASSERT(AF.numRows()==n);
337     ASSERT(AF.numCols()==n);
338 
339     ASSERT(piv.firstIndex()==1);
340     ASSERT(piv.length()==n);
341 
342     ASSERT(B.firstRow()==1);
343     ASSERT(B.firstCol()==1);
344     ASSERT(B.numRows()==n);
345 
346     const IndexType nRhs = B.numCols();
347 
348     ASSERT(X.firstRow()==1);
349     ASSERT(X.firstCol()==1);
350     ASSERT(X.numRows()==n);
351     ASSERT(X.numCols()==nRhs);
352 
353     ASSERT(fErr.firstIndex()==1);
354     ASSERT(fErr.length()==nRhs);
355 
356     ASSERT(bErr.firstIndex()==1);
357     ASSERT(bErr.length()==nRhs);
358 
359     ASSERT(work.firstIndex()==1);
360     ASSERT(work.length()==3*n);
361 
362     ASSERT(iwork.firstIndex()==1);
363     ASSERT(iwork.length()==n);
364 #   endif
365 
366 //
367 //  Make copies of output arguments
368 //
369     typename GeMatrix<MX>::NoView        X_org     = X;
370     typename DenseVector<VFERR>::NoView  fErr_org  = fErr;
371     typename DenseVector<VBERR>::NoView  bErr_org  = bErr;
372     typename DenseVector<VWORK>::NoView  work_org  = work;
373     typename DenseVector<VIWORK>::NoView iwork_org = iwork;
374 //
375 //  Call implementation
376 //
377     rfs_generic(trans, A, AF, piv, B, X, fErr, bErr, work, iwork);
378 
379 #   ifdef CHECK_CXXLAPACK
380 //
381 //  Compare results
382 //
383     typename GeMatrix<MX>::NoView        X_generic     = X;
384     typename DenseVector<VFERR>::NoView  fErr_generic  = fErr;
385     typename DenseVector<VBERR>::NoView  bErr_generic  = bErr;
386     typename DenseVector<VWORK>::NoView  work_generic  = work;
387     typename DenseVector<VIWORK>::NoView iwork_generic = iwork;
388 
389     X     = X_org;
390     fErr  = fErr_org;
391     bErr  = bErr_org;
392     work  = work_org;
393     iwork = iwork_org;
394 
395     rfs_native(trans, A, AF, piv, B, X, fErr, bErr, work, iwork);
396 
397     bool failed = false;
398     if (! isIdentical(X_generic, X, "X_generic""X")) {
399         std::cerr << "CXXLAPACK: X_generic = " << X_generic << std::endl;
400         std::cerr << "F77LAPACK: X = " << X << std::endl;
401         failed = true;
402     }
403 
404     if (! isIdentical(fErr_generic, fErr, "fErr_generic""fErr")) {
405         std::cerr << "CXXLAPACK: fErr_generic = " << fErr_generic << std::endl;
406         std::cerr << "F77LAPACK: fErr = " << fErr << std::endl;
407         failed = true;
408     }
409 
410     if (! isIdentical(bErr_generic, bErr, "bErr_generic""bErr")) {
411         std::cerr << "CXXLAPACK: bErr_generic = " << bErr_generic << std::endl;
412         std::cerr << "F77LAPACK: bErr = " << bErr << std::endl;
413         failed = true;
414     }
415 
416     if (! isIdentical(work_generic, work, "work_generic""work")) {
417         std::cerr << "CXXLAPACK: work_generic = " << work_generic << std::endl;
418         std::cerr << "F77LAPACK: work = " << work << std::endl;
419         failed = true;
420     }
421 
422     if (! isIdentical(iwork_generic, iwork, "iwork_generic""iwork")) {
423         std::cerr << "CXXLAPACK: iwork_generic = "
424                   << iwork_generic << std::endl;
425         std::cerr << "F77LAPACK: iwork = " << iwork << std::endl;
426         failed = true;
427     }
428 
429     if (failed) {
430         ASSERT(0);
431     }
432 #   endif
433 }
434 
435 //-- forwarding ----------------------------------------------------------------
//
//  Convenience overload taking the output arguments X, fErr, bErr, work and
//  iwork as forwarding references (MX&&, ...).  This lets rvalues, such as
//  temporary views, be passed for them; presumably that is the intended use.
//  Inside the body these named parameters are lvalues, so the call is meant
//  to resolve to the more specialized GeMatrix/DenseVector overload of rfs.
//  The call is bracketed by the CHECKPOINT_ENTER / CHECKPOINT_LEAVE macros.
//
template <typename MA, typename MAF, typename VPIV, typename MB, typename MX,
          typename VFERR, typename VBERR, typename VWORK, typename VIWORK>
void
rfs(Transpose    trans,
    const MA     &A,
    const MAF    &AF,
    const VPIV   &piv,
    const MB     &B,
    MX           &&X,
    VFERR        &&fErr,
    VBERR        &&bErr,
    VWORK        &&work,
    VIWORK       &&iwork)
{
    CHECKPOINT_ENTER;
    rfs(trans, A, AF, piv, B, X, fErr, bErr, work, iwork);
    CHECKPOINT_LEAVE;
}
454 
455 } } // namespace lapack, flens
456 
457 #endif // FLENS_LAPACK_GESV_RFS_TCC