1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /*
 34        SUBROUTINE DLATRS( UPLO, TRANS, DIAG, NORMIN, N, A, LDA, X, SCALE,
 35       $                   CNORM, INFO )
 36  *
 37  *  -- LAPACK auxiliary routine (version 3.2) --
 38  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 39  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 40  *     November 2006
 41  */
 42 
 43 #ifndef FLENS_LAPACK_AUX_LATRS_TCC
 44 #define FLENS_LAPACK_AUX_LATRS_TCC 1
 45 
 46 #include <flens/blas/blas.h>
 47 #include <flens/lapack/lapack.h>
 48 
 49 namespace flens { namespace lapack {
 50 
 51 //== generic lapack implementation =============================================
 52 template <typename MA, typename VX, typename SCALE, typename CNORM>
 53 void
 54 latrs_generic(Transpose             trans,
 55               bool                  normIn,
 56               const TrMatrix<MA>    &A,
 57               DenseVector<VX>       &x,
 58               SCALE                 &scale,
 59               DenseVector<CNORM>    &cNorm)
 60 {
 61     using std::abs;
 62     using std::max;
 63     using std::min;
 64 
 65     typedef typename DenseVector<VX>::ElementType  T;
 66     typedef typename DenseVector<VX>::IndexType    IndexType;
 67 
 68     const T Zero(0), Half(0.5), One(1);
 69     const Underscore<IndexType>  _;
 70     const IndexType n = A.dim();
 71     const bool upper = (A.upLo()==Upper);
 72 //
 73 //  Quick return if possible
 74 //
 75     if (n==0) {
 76         return;
 77     }
 78 //
 79 //  Determine machine dependent parameters to control overflow.
 80 //
 81     const T smallNum = lamch<T>(SafeMin) / lamch<T>(Precision);
 82     const T bigNum = One / smallNum;
 83     scale = One;
 84 
 85     if (!normIn) {
 86 //
 87 //      Compute the 1-norm of each column, not including the diagonal.
 88 //
 89         if (upper) {
 90 //
 91 //          A is upper triangular.
 92 //
 93             for (IndexType j=1; j<=n; ++j) {
 94                 cNorm(j) = blas::asum(A(_(1,j-1),j));
 95             }
 96         } else {
 97 //
 98 //          A is lower triangular.
 99 //
100             for (IndexType j=1; j<=n-1; ++j) {
101                 cNorm(j) = blas::asum(A(_(j+1,n),j));
102             }
103             cNorm(n) = Zero;
104         }
105     }
106 //
107 //  Scale the column norms by TSCAL if the maximum element in CNORM is
108 //  greater than BIGNUM.
109 //
110     const IndexType iMax = blas::iamax(cNorm);
111     const T tMax = cNorm(iMax);
112 
113     T tScale;
114     if (tMax<=bigNum) {
115         tScale = One;
116     } else {
117         tScale = One / (smallNum*tMax);
118         cNorm *= tScale;
119     }
120 //
121 //  Compute a bound on the computed solution vector to see if the
122 //  Level 2 BLAS routine DTRSV can be used.
123 //
124     IndexType jFirst, jLast, jEnd, jInc;
125 
126     const IndexType j = blas::iamax(x);
127     T xMax = abs(x(j));
128 
129     T xBound = xMax;
130     T grow;
131 
132     if (trans==NoTrans) {
133 //
134 //      Compute the growth in A * x = b.
135 //
136         if (upper) {
137             jFirst = n;
138             jLast  = 1;
139             jInc   = -1;
140         } else {
141             jFirst = 1;
142             jLast  = n;
143             jInc   = 1;
144         }
145         jEnd = jLast + jInc;
146 
147         if (tScale!=One) {
148             grow = Zero;
149         } else {
150             if (A.diag()==NonUnit) {
151 //
152 //              A is non-unit triangular.
153 //
154 //              Compute GROW = 1/G(j) and XBND = 1/M(j).
155 //              Initially, G(0) = max{x(i), i=1,...,n}.
156 //
157                 grow = One / max(xBound, smallNum);
158                 xBound = grow;
159                 bool grothFactorTooSmall = false;
160                 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
161 //
162 //                  Exit the loop if the growth factor is too small.
163 //
164                     if (grow<=smallNum) {
165                         grothFactorTooSmall = true;
166                         break;
167                     }
168 //
169 //                  M(j) = G(j-1) / abs(A(j,j))
170 //
171                     const T  Tjj = abs(A(j,j));
172                     xBound = min(xBound, min(One,Tjj)*grow);
173                     if (Tjj+cNorm(j)>=smallNum) {
174 //
175 //                      G(j) = G(j-1)*( 1 + CNORM(j) / abs(A(j,j)) )
176 //
177                         grow *= Tjj / (Tjj+cNorm(j));
178                     } else {
179 //
180 //                      G(j) could overflow, set GROW to 0.
181 //
182                         grow = Zero;
183                     }
184                 }
185                 if (!grothFactorTooSmall) {
186                     grow = xBound;
187                 }
188             } else {
189 //
190 //              A is unit triangular.
191 //
192 //              Compute GROW = 1/G(j), where G(0) = max{x(i), i=1,...,n}.
193 //
194                 grow = min(One, One/max(xBound, smallNum));
195                 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
196 //
197 //                  Exit the loop if the growth factor is too small.
198 //
199                     if (grow<=smallNum) {
200                         break;
201                     }
202 //
203 //                  G(j) = G(j-1)*( 1 + CNORM(j) )
204 //
205                     grow *= One / (One+cNorm(j));
206                 }
207             }
208         }
209     } else {
210 //
211 //      Compute the growth in A**T * x = b.
212 //
213         if (upper) {
214             jFirst = 1;
215             jLast  = n;
216             jInc   = 1;
217         } else {
218             jFirst = n;
219             jLast  = 1;
220             jInc   = -1;
221         }
222         jEnd = jLast + jInc;
223 
224         if (tScale!=One) {
225             grow = Zero;
226         } else {
227             if (A.diag()==NonUnit) {
228 //
229 //              A is non-unit triangular.
230 //
231 //              Compute GROW = 1/G(j) and XBND = 1/M(j).
232 //              Initially, M(0) = max{x(i), i=1,...,n}.
233 //
234                 grow = One / max(xBound, smallNum);
235                 xBound = grow;
236 
237                 bool grothFactorTooSmall = false;
238                 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
239 //
240 //                  Exit the loop if the growth factor is too small.
241 //
242                     if (grow<=smallNum) {
243                         grothFactorTooSmall = false;
244                         break;
245                     }
246 //
247 //                  G(j) = max( G(j-1), M(j-1)*( 1 + CNORM(j) ) )
248 //
249                     const T xj = One + cNorm(j);
250                     grow = min(grow, xBound / xj);
251 //
252 //                  M(j) = M(j-1)*( 1 + CNORM(j) ) / abs(A(j,j))
253 //
254                     const T Tjj = abs(A(j,j));
255                     if (xj>Tjj) {
256                         xBound *= Tjj/xj;
257                     }
258                 }
259                 if (!grothFactorTooSmall) {
260                     grow = min(grow, xBound);
261                 }
262             } else {
263 //
264 //              A is unit triangular.
265 //
266 //              Compute GROW = 1/G(j), where G(0) = max{x(i), i=1,...,n}.
267 //
268                 grow = min(One, One / max(xBound, smallNum));
269                 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
270 //
271 //                  Exit the loop if the growth factor is too small.
272 //
273                     if (grow<=smallNum) {
274                         break;
275                     }
276 //
277 //                  G(j) = ( 1 + CNORM(j) )*G(j-1)
278 //
279                     const T xj = One + cNorm(j);
280                     grow /= xj;
281                 }
282             }
283         }
284     }
285 
286     if ((grow*tScale)>smallNum) {
287 //
288 //      Use the Level 2 BLAS solve if the reciprocal of the bound on
289 //      elements of X is not too small.
290 //
291         blas::sv(trans, A, x);
292     } else {
293 //
294 //      Use a Level 1 BLAS solve, scaling intermediate results.
295 //
296         if (xMax>bigNum) {
297 //
298 //          Scale X so that its components are less than or equal to
299 //          BIGNUM in absolute value.
300 //
301             scale = bigNum /xMax;
302             x *= scale;
303             xMax = bigNum;
304         }
305 
306         if (trans==NoTrans) {
307 //
308 //          Solve A * x = b
309 //
310             for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
311 //
312 //              Compute x(j) = b(j) / A(j,j), scaling x if necessary.
313 //
314                 T xj = abs(x(j));
315                 T TjjS;
316 
317                 bool skip = false;
318 
319                 if (A.diag()==NonUnit) {
320                     TjjS = A(j,j) * tScale;
321                 } else {
322                     TjjS = tScale;
323                     if (tScale==One) {
324                         skip = true;
325                     }
326                 }
327                 if (!skip) {
328                     const T  Tjj = abs(TjjS);
329                     if (Tjj>smallNum) {
330 //
331 //                      abs(A(j,j)) > SMLNUM:
332 //
333                         if (Tjj<One) {
334                             if (xj>Tjj*bigNum) {
335 //
336 //                              Scale x by 1/b(j).
337 //
338                                 const T rec = One / xj;
339                                 x *= rec;
340                                 scale *= rec;
341                                 xMax *= rec;
342                             }
343                         }
344                         x(j) /= TjjS;
345                         xj = abs(x(j));
346                     } else if (Tjj>Zero) {
347 //
348 //                      0 < abs(A(j,j)) <= SMLNUM:
349 //
350                         if (xj>Tjj*bigNum) {
351 //
352 //                          Scale x by (1/abs(x(j)))*abs(A(j,j))*BIGNUM
353 //                          to avoid overflow when dividing by A(j,j).
354 //
355                             T rec = (Tjj*bigNum) / xj;
356                             if (cNorm(j)>One) {
357 //
358 //                              Scale by 1/CNORM(j) to avoid overflow when
359 //                              multiplying x(j) times column j.
360 //
361                                 rec /= cNorm(j);
362                             }
363                             x *= rec;
364                             scale *= rec;
365                             xMax *= rec;
366                         }
367                         x(j) /= TjjS;
368                         xj = abs(x(j));
369                     } else {
370 //
371 //                      A(j,j) = 0:  Set x(1:n) = 0, x(j) = 1, and
372 //                      scale = 0, and compute a solution to A*x = 0.
373 //
374                         x = Zero;
375                         x(j) = One;
376                         xj = One;
377                         scale = Zero;
378                         xMax = Zero;
379                     }
380                 }
381 //
382 //              Scale x if necessary to avoid overflow when adding a
383 //              multiple of column j of A.
384 //
385                 if (xj>One) {
386                     T rec = One / xj;
387                     if (cNorm(j)>(bigNum-xMax)*rec) {
388 //
389 //                      Scale x by 1/(2*abs(x(j))).
390 //
391                         rec *= Half;
392                         x *= rec;
393                         scale *= rec;
394                     }
395                 } else if (xj*cNorm(j)>(bigNum-xMax)) {
396 //
397 //                  Scale x by 1/2.
398 //
399                     x *= Half;
400                     scale *= Half;
401                 }
402 
403                 if (upper) {
404                     if (j>1) {
405 //
406 //                      Compute the update
407 //                      x(1:j-1) := x(1:j-1) - x(j) * A(1:j-1,j)
408 //
409                         x(_(1,j-1)) -= (x(j)*tScale) * A(_(1,j-1),j);
410                         const IndexType i = blas::iamax(x(_(1,j-1)));
411                         xMax = abs(x(i));
412                     }
413                 } else {
414                     if (j<n) {
415 //
416 //                      Compute the update
417 //                      x(j+1:n) := x(j+1:n) - x(j) * A(j+1:n,j)
418 //
419                         x(_(j+1,n)) -= (x(j)*tScale) * A(_(j+1,n),j);
420                         const IndexType i = j + blas::iamax(x(_(j+1,n)));
421                         xMax = abs(x(i));
422                     }
423                 }
424             }
425 
426         } else {
427 //
428 //          Solve A**T * x = b
429 //
430             for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
431 //
432 //              Compute x(j) = b(j) - sum A(k,j)*x(k).
433 //                                    k<>j
434 //
435                 T xj = abs(x(j));
436                 T uScale = tScale;
437                 T rec = One / max(xMax, One);
438                 T TjjS = Zero;
439 
440                 if (cNorm(j)>(bigNum-xj)*rec) {
441 //
442 //                  If x(j) could overflow, scale x by 1/(2*XMAX).
443 //
444                     rec *= Half;
445                     if (A.diag()==NonUnit) {
446                         TjjS = A(j,j)*tScale;
447                     } else {
448                         TjjS = tScale;
449                     }
450                     const T  Tjj = abs(TjjS);
451                     if (Tjj>One) {
452 //
453 //                      Divide by A(j,j) when scaling x if A(j,j) > 1.
454 //
455                         rec = min(One, rec*Tjj);
456                         uScale /= TjjS;
457                     }
458                     if (rec<One) {
459                         x *= rec;
460                         scale *= rec;
461                         xMax *= rec;
462                     }
463                 }
464 
465                 T sumJ = Zero;
466                 if (uScale==One) {
467 //
468 //                  If the scaling needed for A in the dot product is 1,
469 //                  call DDOT to perform the dot product.
470 //
471                     if (upper) {
472                         sumJ = A(_(1,j-1),j) * x(_(1,j-1));
473                     } else if (j<n) {
474                         sumJ = A(_(j+1,n),j) * x(_(j+1,n));
475                     }
476                 } else {
477 //
478 //                  Otherwise, use in-line code for the dot product.
479 //
480                     if (upper) {
481                         for (IndexType i=1; i<=j-1; ++i) {
482                             sumJ += (A(i,j)*uScale)*x(i);
483                         }
484                     } else if (j<n) {
485                         for (IndexType i=j+1; i<=n; ++i) {
486                             sumJ += (A(i,j)*uScale)*x(i);
487                         }
488                     }
489                 }
490                 if (uScale==tScale) {
491 //
492 //                  Compute x(j) := ( x(j) - sumj ) / A(j,j) if 1/A(j,j)
493 //                  was not used to scale the dotproduct.
494 //
495                     x(j) -= sumJ;
496                     xj = abs(x(j));
497 
498                     bool skip = false;
499                     if (A.diag()==NonUnit) {
500                         TjjS = A(j,j)*tScale;
501                     } else {
502                         TjjS = tScale;
503                         if (tScale==One) {
504                             skip = true;
505                         }
506                     }
507                     if (!skip) {
508 //
509 //                      Compute x(j) = x(j) / A(j,j), scaling if necessary.
510 //
511                         const T  Tjj = abs(TjjS);
512                         if (Tjj>smallNum) {
513 //
514 //                          abs(A(j,j)) > SMLNUM:
515 //
516                             if (Tjj<One) {
517                                 if (xj>Tjj*bigNum) {
518 //
519 //                                  Scale X by 1/abs(x(j)).
520 //
521                                     rec = One / xj;
522                                     x *= rec;
523                                     scale *= rec;
524                                     xMax *= rec;
525                                 }
526                             }
527                             x(j) /= TjjS;
528                         } else if (Tjj>Zero) {
529 //
530 //                          0 < abs(A(j,j)) <= SMLNUM:
531 //
532                             if (xj>Tjj*bigNum) {
533 //
534 //                              Scale x by (1/abs(x(j)))*abs(A(j,j))*BIGNUM.
535 //
536                                 const T rec = (Tjj*bigNum) / xj;
537                                 x *= rec;
538                                 scale *= rec;
539                                 xMax *= rec;
540                             }
541                             x(j) /= TjjS;
542                         } else {
543 //
544 //                          A(j,j) = 0:  Set x(1:n) = 0, x(j) = 1, and
545 //                          scale = 0, and compute a solution to A**T*x = 0.
546 //
547                             x = Zero;
548                             x(j) = One;
549                             scale = Zero;
550                             xMax = Zero;
551                         }
552                     }
553                 } else {
554 //
555 //                  Compute x(j) := x(j) / A(j,j)  - sumj if the dot
556 //                  product has already been divided by 1/A(j,j).
557 //
558                     x(j) = x(j)/TjjS - sumJ;
559                 }
560                 xMax = max(xMax, abs(x(j)));
561             }
562         }
563         scale /= tScale;
564     }
565 //
566 //  Scale the column norms by 1/TSCAL for return.
567 //
568     if (tScale!=One) {
569         cNorm *= One/tScale;
570     }
571 }
572 
573 //== interface for native lapack ===============================================
574 
575 #ifdef CHECK_CXXLAPACK
576 
577 template <typename MA, typename VX, typename SCALE, typename CNORM>
578 void
579 latrs_native(Transpose             trans,
580              bool                  normIn,
581              const TrMatrix<MA>    &A,
582              DenseVector<VX>       &x,
583              SCALE                 &scale,
584              DenseVector<CNORM>    &cNorm)
585 {
586     typedef typename TrMatrix<MA>::ElementType  T;
587 
588     const char       UPLO   = cxxblas::getF77BlasChar(A.upLo());
589     const char       TRANS  = cxxblas::getF77BlasChar(trans);
590     const char       DIAG   = cxxblas::getF77BlasChar(A.diag());
591     const char       NORMIN = (normIn) ? 'Y' : 'N';
592     const INTEGER    N      = A.dim();
593     const INTEGER    LDA    = A.leadingDimension();
594     INTEGER          INFO;
595 
596     if (IsSame<T,double>::value) {
597         LAPACK_IMPL(dlatrs)(&UPLO,
598                             &TRANS,
599                             &DIAG,
600                             &NORMIN,
601                             &N,
602                             A.data(),
603                             &LDA,
604                             x.data(),
605                             &scale,
606                             cNorm.data(),
607                             &INFO);
608     } else {
609         ASSERT(0);
610     }
611     ASSERT(INFO==0);
612 }
613 
614 #endif // CHECK_CXXLAPACK
615 
616 //== public interface ==========================================================
617 
618 template <typename MA, typename VX, typename SCALE, typename CNORM>
619 void
620 latrs(Transpose             trans,
621       bool                  normIn,
622       const TrMatrix<MA>    &A,
623       DenseVector<VX>       &x,
624       SCALE                 &scale,
625       DenseVector<CNORM>    &cNorm)
626 {
627 //
628 //  Test the input parameters
629 //
630 #   ifndef NDEBUG
631     ASSERT(A.firstRow()==1);
632     ASSERT(A.firstCol()==1);
633     ASSERT(x.firstIndex()==1);
634     ASSERT(x.length()==A.dim());
635     ASSERT(cNorm.firstIndex()==1);
636     ASSERT(cNorm.length()==A.dim());
637 #   endif
638 
639 #   ifdef CHECK_CXXLAPACK
640 //
641 //  Make copies of output arguments
642 //
643     typename DenseVector<VX>::NoView     x_org     = x;
644     SCALE                                scale_org = scale;
645     typename DenseVector<CNORM>::NoView  cNorm_org = cNorm;
646 #   endif
647 
648 //
649 //  Call implementation
650 //
651     latrs_generic(trans, normIn, A, x, scale, cNorm);
652     //latrs_native(trans, normIn, A, x, scale, cNorm);
653 
654 #   ifdef CHECK_CXXLAPACK
655 //
656 //  Compare results
657 //
658     typename DenseVector<VX>::NoView     x_generic     = x;
659     SCALE                                scale_generic = scale;
660     typename DenseVector<CNORM>::NoView  cNorm_generic = cNorm;
661 
662     x     = x_org;
663     scale = scale_org;
664     cNorm = cNorm_org;
665 
666     latrs_native(trans, normIn, A, x, scale, cNorm);
667 
668     bool failed = false;
669     if (! isIdentical(x_generic, x, "x_generic""x")) {
670         std::cerr << "CXXLAPACK: x_generic = " << x_generic << std::endl;
671         std::cerr << "F77LAPACK: x = " << x << std::endl;
672         failed = true;
673     }
674 
675     if (! isIdentical(scale_generic, scale, "scale_generic""scale")) {
676         std::cerr << "CXXLAPACK: scale_generic = "
677                   << scale_generic << std::endl;
678         std::cerr << "F77LAPACK: scale = " << scale << std::endl;
679         failed = true;
680     }
681 
682     if (! isIdentical(cNorm_generic, cNorm, "cNorm_generic""cNorm")) {
683         std::cerr << "CXXLAPACK: cNorm_generic = "
684                   << cNorm_generic << std::endl;
685         std::cerr << "F77LAPACK: cNorm = " << cNorm << std::endl;
686         failed = true;
687     }
688 
689 
690     if (failed) {
691         std::cerr << "x_org = " << x_org << std::endl;
692         std::cerr << "scale_org = " << scale_org << std::endl;
693         ASSERT(0);
694     } else {
695         // std::cerr << "passed: latrs" << std::endl;
696     }
697 #   endif
698 }
699 
700 //-- forwarding ----------------------------------------------------------------
//
//  Convenience overload that accepts x, scale and cNorm as deduced '&&'
//  parameters, so callers may pass temporaries as well as named objects.
//  Inside the body the parameters are named (lvalue) expressions and are
//  handed on unchanged to another latrs overload.
//
//  NOTE(review): presumably the call resolves to the TrMatrix/DenseVector
//  overload when the arguments have those types -- confirm that no argument
//  combination makes this overload select itself.
//
template <typename MA, typename VX, typename SCALE, typename CNORM>
void
latrs(Transpose  trans,
      bool       normIn,
      const MA   &A,
      VX         &&x,
      SCALE      &&scale,
      CNORM      &&cNorm)
{
    // CHECKPOINT_ENTER / CHECKPOINT_LEAVE are project macros defined
    // elsewhere; they bracket the actual call.
    CHECKPOINT_ENTER;
    latrs(trans, normIn, A, x, scale, cNorm);
    CHECKPOINT_LEAVE;
}
714 
715 } } // namespace lapack, flens
716 
717 #endif // FLENS_LAPACK_AUX_LATRS_TCC