1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35        SUBROUTINE DLALN2( LTRANS, NA, NW, SMIN, CA, A, LDA, D1, D2, B,
 36       $                   LDB, WR, WI, X, LDX, SCALE, XNORM, INFO )
 37  *
 38  *  -- LAPACK auxiliary routine (version 3.2) --
 39  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 40  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 41  *     November 2006
 42  *
 43  */
 44 
 45 #ifndef FLENS_LAPACK_EIG_LALN2_TCC
 46 #define FLENS_LAPACK_EIG_LALN2_TCC 1
 47 
 48 #include <flens/blas/blas.h>
 49 #include <flens/lapack/lapack.h>
 50 
 51 namespace flens { namespace lapack {
 52 
 53 //== generic lapack implementation =============================================
 54 
 55 template <typename NW, typename SAFEMIN, typename CA, typename MA,
 56           typename D1, typename D2, typename MB, typename WR, typename WI,
 57           typename MX, typename SCALE, typename XNORM>
 58 typename GeMatrix<MX>::IndexType
 59 laln2_generic(bool                  transA,
 60               NW                    nw,
 61               const SAFEMIN         &safeMin,
 62               const CA              &ca,
 63               const GeMatrix<MA>    &A,
 64               const D1              &d1,
 65               const D2              &d2,
 66               const GeMatrix<MB>    &B,
 67               const WR              &wr,
 68               const WI              &wi,
 69               GeMatrix<MX>          &X,
 70               SCALE                 &scale,
 71               XNORM                 &xNorm)
 72 {
 73     using std::abs;
 74     using flens::max;
 75 
 76     typedef typename GeMatrix<MX>::ElementType  T;
 77     typedef typename GeMatrix<MX>::IndexType    IndexType;
 78 
 79     const T Zero(0), One(1), Two(2);
 80 
 81     const IndexType na = A.numRows();
 82 
 83     IndexType info = 0;
 84 
 85 //
 86 //    .. Local Arrays ..
 87 //
 88     T   _ciData[4], _crData[4];
 89     GeMatrixView<T> CI = typename GeMatrixView<T>::Engine(22, _ciData, 2);
 90     GeMatrixView<T> CR = typename GeMatrixView<T>::Engine(22, _crData, 2);
 91 
 92     DenseVectorView<T> civ = typename DenseVectorView<T>::Engine(4, _ciData);
 93     DenseVectorView<T> crv = typename DenseVectorView<T>::Engine(4, _crData);
 94 
 95     bool _rSwapData[4] = {falsetruefalsetrue};
 96     bool _zSwapData[4] = {falsefalsetruetrue};
 97     DenseVectorView<bool>
 98         rSwap = typename DenseVectorView<bool>::Engine(4, _rSwapData),
 99         zSwap = typename DenseVectorView<bool>::Engine(4, _zSwapData);
100 
101     IndexType _iPivotData[16] = { 1234,
102                                   2143,
103                                   3412,
104                                   4321};
105     GeMatrixView<IndexType>
106         iPivot = typename GeMatrixView<IndexType>::Engine(44, _iPivotData, 4);
107 
108 //
109 //  Compute BIGNUM
110 //
111     const T smallNum = Two*lamch<T>(SafeMin);
112     const T bigNum = One / smallNum;
113     const T safeMini = max(safeMin, smallNum);
114 //
115 //  Standard Initializations
116 //
117     scale = One;
118 
119     if (na==1) {
120 //
121 //      1 x 1  (i.e., scalar) system   C X = B
122 //
123         if (nw==1) {
124 //
125 //          Real 1x1 system.
126 //
127 //          C = ca A - w D
128 //
129             T csr   = ca*A(1,1) - wr*d1;
130             T cNorm = abs(csr);
131 //
132 //          If | C | < SMINI, use C = SMINI
133 //
134             if (cNorm<safeMini) {
135                 csr     = safeMini;
136                 cNorm   = safeMin;
137                 info    = 1;
138             }
139 //
140 //          Check scaling for  X = B / C
141 //
142             T bNorm = abs(B(1,1));
143             if (cNorm<One && bNorm>One) {
144                 if (bNorm>bigNum*cNorm) {
145                     scale = One / bNorm;
146                 }
147             }
148 //
149 //          Compute X
150 //
151             X(1,1) = (B(1,1)*scale) / csr;
152             xNorm = abs(X(1,1));
153         } else {
154 //
155 //          Complex 1x1 system (w is complex)
156 //
157 //          C = ca A - w D
158 //
159             T csr = ca*A(1,1) - wr*d1;
160             T csi = -wi*d1;
161             T cNorm = abs(csr) + abs(csi);
162 //
163 //          If | C | < SMINI, use C = SMINI
164 //
165             if (cNorm<safeMini) {
166                 csr     = safeMini;
167                 csi     = Zero;
168                 cNorm   = safeMini;
169                 info    = 1;
170             }
171 //
172 //          Check scaling for  X = B / C
173 //
174             T bNorm = abs(B(1,1)) + abs(B(1,2));
175             if (cNorm<One && bNorm>One) {
176                 if (bNorm>bigNum*cNorm) {
177                     scale = One / bNorm;
178                 }
179             }
180 //
181 //          Compute X
182 //
183             ladiv(scale*B(1,1), scale*B(1,2), csr, csi, X(1,1), X(1,2));
184             xNorm = abs(X(1,1)) + abs(X(1,2));
185         }
186 
187     } else {
188 //
189 //      2x2 System
190 //
191 //      Compute the real part of  C = ca A - w D  (or  ca A**T - w D )
192 //
193         CR(1,1) = ca*A(1,1) - wr*d1;
194         CR(2,2) = ca*A(2,2) - wr*d2;
195         if (transA) {
196             CR(1,2) = ca*A(2,1);
197             CR(2,1) = ca*A(1,2);
198         } else {
199             CR(2,1) = ca*A(2,1);
200             CR(1,2) = ca*A(1,2);
201         }
202 
203         if (nw==1) {
204 //
205 //          Real 2x2 system  (w is real)
206 //
207 //          Find the largest element in C
208 //
209             T cMax = Zero;
210             IndexType icMax = 0;
211 
212             for (IndexType j=1; j<=4; ++j) {
213                 if (abs(crv(j))>cMax) {
214                     cMax = abs(crv(j));
215                     icMax = j;
216                 }
217             }
218 //
219 //          If norm(C) < SMINI, use SMINI*identity.
220 //
221             if (cMax<safeMini) {
222                 const T bNorm = max(abs(B(1,1)), abs(B(2,1)));
223                 if (safeMini<One && bNorm>One) {
224                     if (bNorm>bigNum*safeMini) {
225                         scale = One/bNorm;
226                     }
227                 }
228                 const T temp = scale / safeMini;
229                 X(1,1) = temp*B(1,1);
230                 X(2,1) = temp*B(2,1);
231                 xNorm  = temp*bNorm;
232                 info   = 1;
233                 return info;
234             }
235 //
236 //          Gaussian elimination with complete pivoting.
237 //
238             T UR11  = crv(icMax);
239             T CR21  = crv(iPivot(2,icMax));
240             T UR12  = crv(iPivot(3,icMax));
241             T CR22  = crv(iPivot(4,icMax));
242             T UR11R = One / UR11;
243             T LR21  = UR11R*CR21;
244             T UR22  = CR22 - UR12*LR21;
245 //
246 //          If smaller pivot < SMINI, use SMINI
247 //
248             if (abs(UR22)<safeMini) {
249                 UR22 = safeMini;
250                 info = 1;
251             }
252             T BR1, BR2;
253             if (rSwap(icMax)) {
254                 BR1 = B(2,1);
255                 BR2 = B(1,1);
256             } else {
257                 BR1 = B(1,1);
258                 BR2 = B(2,1);
259             }
260             BR2 = BR2 - LR21*BR1;
261 
262             const T BBND = max(abs(BR1*(UR22*UR11R)), abs(BR2));
263             if (BBND>One && abs(UR22)<One) {
264                 if (BBND>=bigNum*abs(UR22)) {
265                     scale = One / BBND;
266                 }
267             }
268 
269             const T XR2 = (BR2*scale) / UR22;
270             const T XR1 = (scale*BR1)*UR11R - XR2*(UR11R*UR12);
271             if (zSwap(icMax)) {
272                 X(1,1) = XR2;
273                 X(2,1) = XR1;
274             } else {
275                 X(1,1) = XR1;
276                 X(2,1) = XR2;
277             }
278             xNorm = max(abs(XR1), abs(XR2));
279 //
280 //          Further scaling if  norm(A) norm(X) > overflow
281 //
282             if (xNorm>One && cMax>One) {
283                 if (xNorm>bigNum/cMax) {
284                     const T temp = cMax / bigNum;
285                     X(1,1) = temp*X(1,1);
286                     X(2,1) = temp*X(2,1);
287                     xNorm = temp*xNorm;
288                     scale = temp*scale;
289                 }
290             }
291         } else {
292 //
293 //          Complex 2x2 system  (w is complex)
294 //
295 //          Find the largest element in C
296 //
297             CI(1,1) = -wi*d1;
298             CI(2,1) = Zero;
299             CI(1,2) = Zero;
300             CI(2,2) = -wi*d2;
301 
302             T cMax  = Zero;
303             IndexType icMax   = 0;
304 
305             for (IndexType j=1; j<=4; ++j) {
306                 if (abs(crv(j))+abs(civ(j))>cMax) {
307                     cMax = abs(crv(j)) + abs(civ(j));
308                     icMax = j;
309                 }
310             }
311 //
312 //          If norm(C) < SMINI, use SMINI*identity.
313 //
314             if (cMax<safeMini) {
315                 const T bNorm = max(abs(B(1,1))+abs(B(1,2)), abs(B(2,1))+abs(B(2,2)));
316                 if (safeMini<One && bNorm>One) {
317                     if (bNorm>bigNum*safeMini) {
318                         scale = One / bNorm;
319                     }
320                 }
321                 const T temp = scale / safeMini;
322                 X(1,1) = temp*B(1,1);
323                 X(2,1) = temp*B(2,1);
324                 X(1,2) = temp*B(1,2);
325                 X(2,2) = temp*B(2,2);
326                 xNorm = temp*bNorm;
327                 info = 1;
328                 return info;
329             }
330 //
331 //          Gaussian elimination with complete pivoting.
332 //
333             const T UR11 = crv(icMax);
334             const T UI11 = civ(icMax);
335             const T CR21 = crv(iPivot(2,icMax));
336             const T CI21 = civ(iPivot(2,icMax));
337             const T UR12 = crv(iPivot(3,icMax));
338             const T UI12 = civ(iPivot(3,icMax));
339             const T CR22 = crv(iPivot(4,icMax));
340             const T CI22 = civ(iPivot(4,icMax));
341 
342             T UR11R, UI11R, LR21, LI21, UR12S, UI12S, UR22, UI22;
343 
344             if (icMax==1 || icMax==4) {
345 //
346 //              Code when off-diagonals of pivoted C are real
347 //
348                 if (abs(UR11)>abs(UI11)) {
349                     const T temp = UI11 / UR11;
350                     UR11R = One / (UR11*(One+pow(temp,2)));
351                     UI11R = -temp*UR11R;
352                 } else {
353                     const T temp = UR11 / UI11;
354                     UI11R = -One / (UI11*(One+pow(temp,2)));
355                     UR11R = -temp*UI11R;
356                 }
357                 LR21 = CR21*UR11R;
358                 LI21 = CR21*UI11R;
359                 UR12S = UR12*UR11R;
360                 UI12S = UR12*UI11R;
361                 UR22 = CR22 - UR12*LR21;
362                 UI22 = CI22 - UR12*LI21;
363             } else {
364 //
365 //              Code when diagonals of pivoted C are real
366 //
367                 UR11R = One / UR11;
368                 UI11R = Zero;
369                 LR21 = CR21*UR11R;
370                 LI21 = CI21*UR11R;
371                 UR12S = UR12*UR11R;
372                 UI12S = UI12*UR11R;
373                 UR22 = CR22 - UR12*LR21 + UI12*LI21;
374                 UI22 = -UR12*LI21 - UI12*LR21;
375             }
376             const T U22ABS = abs(UR22) + abs(UI22);
377 //
378 //          If smaller pivot < SMINI, use SMINI
379 //
380             T BR1, BR2, BI1, BI2;
381 
382             if (U22ABS<safeMini) {
383                 UR22 = safeMini;
384                 UI22 = Zero;
385                 info = 1;
386             }
387             if (rSwap(icMax)) {
388                 BR2 = B(1,1);
389                 BR1 = B(2,1);
390                 BI2 = B(1,2);
391                 BI1 = B(2,2);
392             } else {
393                 BR1 = B(1,1);
394                 BR2 = B(2,1);
395                 BI1 = B(1,2);
396                 BI2 = B(2,2);
397             }
398             BR2 = BR2 - LR21*BR1 + LI21*BI1;
399             BI2 = BI2 - LI21*BR1 - LR21*BI1;
400             const T BBND = max((abs(BR1)+abs(BI1))
401                                *(U22ABS*(abs(UR11R)+abs(UI11R))),
402                               abs(BR2)+abs(BI2));
403             if (BBND>One && U22ABS<One) {
404                 if (BBND>=bigNum*U22ABS) {
405                     scale = One / BBND;
406                     BR1 = scale*BR1;
407                     BI1 = scale*BI1;
408                     BR2 = scale*BR2;
409                     BI2 = scale*BI2;
410                 }
411             }
412 
413             T XR1, XR2, XI1, XI2;
414             ladiv(BR2, BI2, UR22, UI22, XR2, XI2);
415             XR1 = UR11R*BR1 - UI11R*BI1 - UR12S*XR2 + UI12S*XI2;
416             XI1 = UI11R*BR1 + UR11R*BI1 - UI12S*XR2 - UR12S*XI2;
417             if (zSwap(icMax)) {
418                 X(1,1) = XR2;
419                 X(2,1) = XR1;
420                 X(1,2) = XI2;
421                 X(2,2) = XI1;
422             } else {
423                 X(1,1) = XR1;
424                 X(2,1) = XR2;
425                 X(1,2) = XI1;
426                 X(2,2) = XI2;
427             }
428             xNorm = max(abs(XR1)+abs(XI1), abs(XR2)+abs(XI2));
429 //
430 //          Further scaling if  norm(A) norm(X) > overflow
431 //
432             if (xNorm>One && cMax>One) {
433                 if (xNorm>bigNum/cMax) {
434                     const T temp = cMax / bigNum;
435                     X(1,1) = temp*X(1,1);
436                     X(2,1) = temp*X(2,1);
437                     X(1,2) = temp*X(1,2);
438                     X(2,2) = temp*X(2,2);
439                     xNorm *= temp;
440                     scale *= temp;
441                 }
442             }
443         }
444     }
445 
446     return info;
447 }
448 
449 
450 //== interface for native lapack ===============================================
451 
452 #ifdef CHECK_CXXLAPACK
453 
454 template <typename NW, typename SAFEMIN, typename CA, typename MA,
455           typename D1, typename D2, typename MB, typename WR, typename WI,
456           typename MX, typename SCALE, typename XNORM>
457 typename GeMatrix<MX>::IndexType
458 laln2_native(bool                  transA,
459              NW                    nw,
460              const SAFEMIN         &safeMin,
461              const CA              &ca,
462              const GeMatrix<MA>    &A,
463              const D1              &d1,
464              const D2              &d2,
465              const GeMatrix<MB>    &B,
466              const WR              &wr,
467              const WI              &wi,
468              GeMatrix<MX>          &X,
469              SCALE                 &scale,
470              XNORM                 &xNorm)
471 {
472     typedef typename GeMatrix<MX>::ElementType  T;
473 
474     const LOGICAL    LTRANS = transA;
475     const INTEGER    NA     = A.numRows();
476     const INTEGER    _NW    = nw;
477     const T          SMIN   = safeMin;
478     const T          _CA    = ca;
479     const INTEGER    LDA    = A.leadingDimension();
480     const T          _D1    = d1;
481     const T          _D2    = d2;
482     const INTEGER    LDB    = B.leadingDimension();
483     const T          _WR    = wr;
484     const T          _WI    = wi;
485     const INTEGER    LDX    = X.leadingDimension();
486     T                _SCALE = scale;
487     T                _XNORM = xNorm;
488     INTEGER          INFO;
489 
490     if (IsSame<T,DOUBLE>::value) {
491         LAPACK_IMPL(dlaln2)(&LTRANS,
492                             &NA,
493                             &_NW,
494                             &SMIN,
495                             &_CA,
496                             A.data(),
497                             &LDA,
498                             &_D1,
499                             &_D2,
500                             B.data(),
501                             &LDB,
502                             &_WR,
503                             &_WI,
504                             X.data(),
505                             &LDX,
506                             &_SCALE,
507                             &_XNORM,
508                             &INFO);
509     } else {
510         ASSERT(0);
511     }
512     scale = _SCALE;
513     xNorm = _XNORM;
514     ASSERT(INFO>=0);
515     return INFO;
516 }
517 
518 #endif // CHECK_CXXLAPACK
519 
520 //== public interface ==========================================================
521 
522 template <typename NW, typename SAFEMIN, typename CA, typename MA,
523           typename D1, typename D2, typename MB, typename WR, typename WI,
524           typename MX, typename SCALE, typename XNORM>
525 typename GeMatrix<MX>::IndexType
526 laln2(bool                  transA,
527       NW                    nw,
528       const SAFEMIN         &safeMin,
529       const CA              &ca,
530       const GeMatrix<MA>    &A,
531       const D1              &d1,
532       const D2              &d2,
533       const GeMatrix<MB>    &B,
534       const WR              &wr,
535       const WI              &wi,
536       GeMatrix<MX>          &X,
537       SCALE                 &scale,
538       XNORM                 &xNorm)
539 {
540     LAPACK_DEBUG_OUT("BEGIN: laln2");
541 
542     typedef typename GeMatrix<MX>::IndexType IndexType;
543 
544 //
545 //  Test the input parameters
546 //
547 #   ifndef NDEBUG
548     ASSERT(A.numRows()==A.numCols());
549     const IndexType na = A.numRows();
550     ASSERT(na==1 || na==2);
551 
552     ASSERT(nw==1 || nw==2);
553 
554     ASSERT(B.numRows()==na);
555     ASSERT(B.numCols()==nw);
556 
557     ASSERT(X.numRows()==na);
558     ASSERT(X.numCols()==nw);
559 #   endif
560 
561 #   ifdef CHECK_CXXLAPACK
562 //
563 //  Make copies of output arguments
564 //
565     typename GeMatrix<MX>::NoView   X_org     = X;
566     SCALE                           scale_org = scale;
567     XNORM                           xNorm_org = xNorm;
568 #   endif
569 
570 //
571 //  Call implementation
572 //
573     const IndexType info = laln2_generic(transA, nw, safeMin, ca,
574                                          A, d1, d2, B, wr, wi,
575                                          X, scale, xNorm);
576 
577 #   ifdef CHECK_CXXLAPACK
578 //
579 //  Make copies of results computed by the generic implementation
580 //
581     typename GeMatrix<MX>::NoView   X_generic     = X;
582     SCALE                           scale_generic = scale;
583     XNORM                           xNorm_generic = xNorm;
584 
585 //
586 //  restore output arguments
587 //
588     X     = X_org;
589     scale = scale_org;
590     xNorm = xNorm_org;
591 
592 //
593 //  Compare generic results with results from the native implementation
594 //
595     const IndexType _info = laln2_native(transA, nw, safeMin, ca,
596                                          A, d1, d2, B, wr, wi,
597                                          X, scale, xNorm);
598 
599     bool failed = false;
600     if (! isIdentical(X_generic, X, "X_generic""X")) {
601         std::cerr << "CXXLAPACK: X_generic = " << X_generic << std::endl;
602         std::cerr << "F77LAPACK: X = " << X << std::endl;
603         failed = true;
604     }
605 
606     if (! isIdentical(scale_generic, scale, "scale_generic""scale")) {
607         std::cerr << "CXXLAPACK: scale_generic = "
608                   << scale_generic << std::endl;
609         std::cerr << "F77LAPACK: scale = " << scale << std::endl;
610         failed = true;
611     }
612 
613     if (! isIdentical(xNorm_generic, xNorm, "xNorm_generic""xNorm")) {
614         std::cerr << "CXXLAPACK: xNorm_generic = "
615                   << xNorm_generic << std::endl;
616         std::cerr << "F77LAPACK: xNorm = " << xNorm << std::endl;
617         failed = true;
618     }
619 
620     if (! isIdentical(info, _info, "info""_info")) {
621         std::cerr << "CXXLAPACK: info = " << info << std::endl;
622         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
623         failed = true;
624     }
625 
626     if (failed) {
627         std::cerr << "error in: laln2.tcc" << std::endl;
628         ASSERT(0);
629     } else {
630 //        std::cerr << "passed: laln2.tcc" << std::endl;
631     }
632 #   endif
633 
634     LAPACK_DEBUG_OUT("END: laln2");
635 
636     return info;
637 }
638 
639 
640 //-- forwarding ----------------------------------------------------------------
641 
//
//  Forwarding overload:  accepts X as an rvalue (e.g. a temporary matrix
//  view).  Inside the function the named parameter X is an lvalue, so the
//  call below passes it on as a non-const reference and returns the info
//  value of that call unchanged.
//
template <typename NW, typename SAFEMIN, typename CA, typename MA,
          typename D1, typename D2, typename MB, typename WR, typename WI,
          typename MX, typename SCALE, typename XNORM>
typename MX::IndexType
laln2(bool                  transA,
      NW                    nw,
      const SAFEMIN         &safeMin,
      const CA              &ca,
      const MA              &A,
      const D1              &d1,
      const D2              &d2,
      const MB              &B,
      const WR              &wr,
      const WI              &wi,
      MX                    &&X,
      SCALE                 &scale,
      XNORM                 &xNorm)
{
    const typename MX::IndexType info = laln2(transA, nw, safeMin, ca,
                                              A, d1, d2,
                                              B, wr, wi,
                                              X, scale, xNorm);
    return info;
}
663 
664 } } // namespace lapack, flens
665 
666 #endif // FLENS_LAPACK_EIG_LALN2_TCC