1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35        SUBROUTINE DLASY2( LTRANL, LTRANR, ISGN, N1, N2, TL, LDTL, TR,
 36       $                   LDTR, B, LDB, SCALE, X, LDX, XNORM, INFO )
 37  *
 38  *  -- LAPACK auxiliary routine (version 3.2) --
 39  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 40  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 41  *     November 2006
 42  */
 43 
 44 #ifndef FLENS_LAPACK_EIG_LASY2_TCC
 45 #define FLENS_LAPACK_EIG_LASY2_TCC 1
 46 
 47 #include <flens/blas/blas.h>
 48 #include <flens/lapack/lapack.h>
 49 
 50 namespace flens { namespace lapack {
 51 
 52 //== generic lapack implementation =============================================
 53 
 54 template <typename SIGN, typename MTL, typename MTR, typename MB,
 55           typename SCALE, typename MX, typename XNORM>
 56 typename GeMatrix<MX>::IndexType
 57 lasy2_generic(bool                  transLeft,
 58               bool                  transRight,
 59               SIGN                  iSign,
 60               const GeMatrix<MTL>   &TL,
 61               const GeMatrix<MTR>   &TR,
 62               const GeMatrix<MB>    &B,
 63               SCALE                 &scale,
 64               GeMatrix<MX>          &X,
 65               XNORM                 &xNorm)
 66 {
 67     using std::abs;
 68     using flens::max;
 69     using std::swap;
 70 
 71     typedef typename GeMatrix<MX>::ElementType  ElementType;
 72     typedef typename GeMatrix<MX>::IndexType    IndexType;
 73 
 74     const ElementType Zero(0), Half(0.5), One(1), Two(2), Eight(8);
 75 
 76     const IndexType n1 = TL.numRows();
 77     const IndexType n2 = TR.numRows();
 78 
 79     const Underscore<IndexType> _;
 80 
 81     IndexType   info = 0;
 82     IndexType   iPiv;
 83 
 84     bool  bSwap, xSwap;
 85     ElementType safeMin, beta, gamma, tau1, U11, U12, U22, L21, temp;
 86 
 87 //
 88 //    .. Local Arrays ..
 89 //
 90     IndexType  _jPivData[4],
 91                _locU12Data[4] = { 3412},
 92                _locL21Data[4] = { 2143},
 93                _locU22Data[4] = { 4321};
 94     DenseVectorView<IndexType>
 95         jPiv   = typename DenseVectorView<IndexType>::Engine(4, _jPivData),
 96         locU12 = typename DenseVectorView<IndexType>::Engine(4, _locU12Data),
 97         locL21 = typename DenseVectorView<IndexType>::Engine(4, _locL21Data),
 98         locU22 = typename DenseVectorView<IndexType>::Engine(4, _locU22Data);
 99 
100     ElementType   _bTmpData[4], _tmpData[4], _x2Data[2];
101     DenseVectorView<ElementType>
102         bTmp = typename DenseVectorView<ElementType>::Engine(4, _bTmpData),
103         tmp  = typename DenseVectorView<ElementType>::Engine(4, _tmpData),
104         x2   = typename DenseVectorView<ElementType>::Engine(2, _x2Data);
105 
106     bool  _xSwapPivData[4] = {falsefalsetruetrue},
107           _bSwapPivData[4] = {falsetruefalsetrue};
108     DenseVectorView<bool>
109         xSwapPiv  = typename DenseVectorView<bool>::Engine(4, _xSwapPivData),
110         bSwapPiv  = typename DenseVectorView<bool>::Engine(4, _bSwapPivData);
111 
112     ElementType  _t16Data[16];
113     GeMatrixView<ElementType>
114         T16 = typename GeMatrixView<ElementType>::Engine(44, _t16Data, 4);
115 
116 //
117 //  Quick return if possible
118 //
119     if (n1==0 || n2==0) {
120         return info;
121     }
122 //
123 //  Set constants to control overflow
124 //
125     const ElementType eps = lamch<ElementType>(Precision);
126     const ElementType smallNum = lamch<ElementType>(SafeMin) / eps;
127     const ElementType sign = iSign;
128 
129     const IndexType k = n1 + n1 + n2 - 2;
130 
131     switch (k) {
132 //
133 //  1 by 1: TL11*X + SGN*X*TR11 = B11
134 //
135     case 1:
136         tau1 = TL(1,1) + sign*TR(1,1);
137         beta = abs(tau1);
138         if (beta<=smallNum) {
139             tau1 = smallNum;
140             beta = smallNum;
141             info = 1;
142         }
143 
144         scale = One;
145         gamma = abs(B(1,1));
146         if (smallNum*gamma>beta) {
147             scale = One/gamma;
148         }
149 
150         X(1,1) = (B(1,1)*scale) / tau1;
151         xNorm = abs(X(1,1));
152         return info;
153 
154     case 2:
155     case 3:
156         if (k==2) {
157 //
158 //          1 by 2:
159 //          TL11*[X11 X12] + ISGN*[X11 X12]*op[TR11 TR12]  = [B11 B12]
160 //                                            [TR21 TR22]
161 //
162             safeMin = max(eps*max(abs(TL(1,1)), abs(TR(1,1)),
163                                   abs(TR(1,2)), abs(TR(2,1)),
164                                   abs(TR(2,2))),
165                           smallNum);
166             tmp(1) = TL(1,1) + sign*TR(1,1);
167             tmp(4) = TL(1,1) + sign*TR(2,2);
168             if (transRight) {
169                 tmp(2) = sign*TR(2,1);
170                 tmp(3) = sign*TR(1,2);
171             } else {
172                 tmp(2) = sign*TR(1,2);
173                 tmp(3) = sign*TR(2,1);
174             }
175             bTmp(1) = B(1,1);
176             bTmp(2) = B(1,2);
177         } else {
178 //
179 //          2 by 1:
180 //          op[TL11 TL12]*[X11] + ISGN* [X11]*TR11  = [B11]
181 //            [TL21 TL22] [X21]         [X21]         [B21]
182 //
183             safeMin = max(eps*max(abs(TR(1,1)), abs(TL(1,1)),
184                                   abs(TL(1,2)), abs(TL(2,1)),
185                                   abs(TL(2,2))),
186                           smallNum);
187             tmp(1) = TL(1,1) + sign*TR(1,1);
188             tmp(4) = TL(2,2) + sign*TR(1,1);
189             if (transLeft) {
190                 tmp(2) = TL(1,2);
191                 tmp(3) = TL(2,1);
192             } else {
193                 tmp(2) = TL(2,1);
194                 tmp(3) = TL(1,2);
195             }
196             bTmp(1) = B(1,1);
197             bTmp(2) = B(2,1);
198         }
199 //
200 //      Solve 2 by 2 system using complete pivoting.
201 //      Set pivots less than SMIN to SMIN.
202 //
203         iPiv = blas::iamax(tmp);
204         U11 = tmp(iPiv);
205         if (abs(U11)<=safeMin) {
206             info = 1;
207             U11 = safeMin;
208         }
209         U12 = tmp(locU12(iPiv));
210         L21 = tmp(locL21(iPiv)) / U11;
211         U22 = tmp(locU22(iPiv)) - U12*L21;
212         xSwap = xSwapPiv(iPiv);
213         bSwap = bSwapPiv(iPiv);
214         if (abs(U22)<=safeMin) {
215             info = 1;
216             U22 = safeMin;
217         }
218         if (bSwap) {
219             temp = bTmp(2);
220             bTmp(2) = bTmp(1) - L21*temp;
221             bTmp(1) = temp;
222         } else {
223             bTmp(2) -= L21*bTmp(1);
224         }
225         scale = One;
226         if ((Two*smallNum)*abs(bTmp(2))>abs(U22)
227          || (Two*smallNum)*abs(bTmp(1))>abs(U11))
228         {
229             scale = Half / max(abs(bTmp(1)), abs(bTmp(2)));
230             bTmp(1) *= scale;
231             bTmp(2) *= scale;
232         }
233         x2(2) = bTmp(2)/U22;
234         x2(1) = bTmp(1)/U11 - (U12/U11)*x2(2);
235         if (xSwap) {
236             swap(x2(1), x2(2));
237         }
238         X(1,1) = x2(1);
239         if (n1==1) {
240             X(1,2) = x2(2);
241             xNorm = abs(X(1,1)) + abs(X(1,2));
242         } else {
243             X(2,1) = x2(2);
244             xNorm = max(abs(X(1,1)), abs(X(2,1)));
245         }
246         return info;
247 
248 //
249 //   2 by 2:
250 //   op[TL11 TL12]*[X11 X12] +ISGN* [X11 X12]*op[TR11 TR12] = [B11 B12]
251 //     [TL21 TL22] [X21 X22]        [X21 X22]   [TR21 TR22]   [B21 B22]
252 //
253 //   Solve equivalent 4 by 4 system using complete pivoting.
254 //   Set pivots less than SMIN to SMIN.
255 //
256     case 4:
257         safeMin = max(abs(TR(1,1)), abs(TR(1,2)), abs(TR(2,1)), abs(TR(2,2)));
258         safeMin = max(safeMin, abs(TL(1,1)), abs(TL(1,2)),
259                                abs(TL(2,1)), abs(TL(2,2)));
260         safeMin = max(eps*safeMin, smallNum);
261         bTmp(1) = Zero;
262         T16 = 0;
263         T16(1,1) = TL(1,1) + sign*TR(1,1);
264         T16(2,2) = TL(2,2) + sign*TR(1,1);
265         T16(3,3) = TL(1,1) + sign*TR(2,2);
266         T16(4,4) = TL(2,2) + sign*TR(2,2);
267         if (transLeft) {
268             T16(1,2) = TL(2,1);
269             T16(2,1) = TL(1,2);
270             T16(3,4) = TL(2,1);
271             T16(4,3) = TL(1,2);
272         } else {
273             T16(1,2) = TL(1,2);
274             T16(2,1) = TL(2,1);
275             T16(3,4) = TL(1,2);
276             T16(4,3) = TL(2,1);
277         }
278         if (transRight) {
279             T16(1,3) = sign*TR(1,2);
280             T16(2,4) = sign*TR(1,2);
281             T16(3,1) = sign*TR(2,1);
282             T16(4,2) = sign*TR(2,1);
283         } else {
284             T16(1,3) = sign*TR(2,1);
285             T16(2,4) = sign*TR(2,1);
286             T16(3,1) = sign*TR(1,2);
287             T16(4,2) = sign*TR(1,2);
288         }
289         bTmp(1) = B(1,1);
290         bTmp(2) = B(2,1);
291         bTmp(3) = B(1,2);
292         bTmp(4) = B(2,2);
293 //
294 //      Perform elimination
295 //
296         for (IndexType i=1; i<=3; ++i) {
297             ElementType xMax = Zero;
298             IndexType ipSv = -1, jpSv = -1;
299 
300             for (IndexType ip=i; ip<=4; ++ip) {
301                 for (IndexType jp=i; jp<=4; ++jp) {
302                     if (abs(T16(ip,jp))>=xMax) {
303                         xMax = abs(T16(ip,jp));
304                         ipSv = ip;
305                         jpSv = jp;
306                     }
307                 }
308             }
309             if (ipSv!=i) {
310                 blas::swap(T16(ipSv,_), T16(i,_));
311                 swap(bTmp(i), bTmp(ipSv));
312             }
313             if (jpSv!=i) {
314                 blas::swap(T16(_,jpSv), T16(_,i));
315             }
316             jPiv(i) = jpSv;
317             if (abs(T16(i,i))<safeMin) {
318                 info = 1;
319                 T16(i,i) = safeMin;
320             }
321             for (IndexType j=i+1; j<=4; ++j) {
322                 T16(j,i) /= T16(i,i);
323                 bTmp(j)  -= T16(j,i)*bTmp(i);
324                 for (IndexType k=i+1; k<=4; ++k) {
325                     T16(j,k) -= T16(j,i)*T16(i,k);
326                 }
327             }
328         }
329         if (abs(T16(4,4))<safeMin) {
330             T16(4,4) = safeMin;
331         }
332         scale = One;
333         if ((Eight*smallNum)*abs(bTmp(1))>abs(T16(1,1))
334          || (Eight*smallNum)*abs(bTmp(2))>abs(T16(2,2))
335          || (Eight*smallNum)*abs(bTmp(3))>abs(T16(3,3))
336          || (Eight*smallNum)*abs(bTmp(4))>abs(T16(4,4)))
337         {
338             scale = (One/Eight) / max(abs(bTmp(1)), abs(bTmp(2)),
339                                       abs(bTmp(3)), abs(bTmp(4)));
340             bTmp(1) *= scale;
341             bTmp(2) *= scale;
342             bTmp(3) *= scale;
343             bTmp(4) *= scale;
344         }
345         for (IndexType i=1; i<=4; ++i) {
346             IndexType k = 5 - i;
347             const ElementType temp = One/T16(k,k);
348             tmp(k) = bTmp(k)*temp;
349             for (IndexType j=k+1; j<=4; ++j) {
350                 tmp(k) -= (temp*T16(k,j))*tmp(j);
351             }
352         }
353         for (IndexType i=1; i<=3; ++i) {
354             if (jPiv(4-i)!=4-i) {
355                 swap(tmp(4-i), tmp(jPiv(4-i)));
356             }
357         }
358         X(1,1) = tmp(1);
359         X(2,1) = tmp(2);
360         X(1,2) = tmp(3);
361         X(2,2) = tmp(4);
362         xNorm = max(abs(tmp(1))+abs(tmp(3)), abs(tmp(2))+abs(tmp(4)));
363         return info;
364     }
365 
366     // error if switch does not handle all cases
367     ASSERT(0);
368     return info;
369 }
370 
371 //== interface for native lapack ===============================================
372 
373 #ifdef CHECK_CXXLAPACK
374 
375 template <typename SIGN, typename MTL, typename MTR, typename MB,
376           typename SCALE, typename MX, typename XNORM>
377 typename GeMatrix<MX>::IndexType
378 lasy2_native(bool                  transLeft,
379              bool                  transRight,
380              SIGN                  sign,
381              const GeMatrix<MTL>   &TL,
382              const GeMatrix<MTR>   &TR,
383              const GeMatrix<MB>    &B,
384              SCALE                 &scale,
385              GeMatrix<MX>          &X,
386              XNORM                 &xNorm)
387 {
388     typedef typename GeMatrix<MX>::ElementType ElementType;
389 
390     const LOGICAL    LTRANL     = transLeft;
391     const LOGICAL    LTRANR     = transRight;
392     const INTEGER    ISGN       = sign;
393     const INTEGER    N1         = TL.numRows();
394     const INTEGER    N2         = TR.numRows();
395     const INTEGER    LDTL       = TL.leadingDimension();
396     const INTEGER    LDTR       = TR.leadingDimension();
397     const INTEGER    LDB        = B.leadingDimension();
398     ElementType      _SCALE     = scale;
399     const INTEGER    LDX        = X.leadingDimension();
400     ElementType      _XNORM     = xNorm;
401     INTEGER          INFO;
402 
403     if (IsSame<ElementType,DOUBLE>::value) {
404         LAPACK_IMPL(dlasy2)(&LTRANL,
405                             &LTRANR,
406                             &ISGN,
407                             &N1,
408                             &N2,
409                             TL.data(),
410                             &LDTL,
411                             TR.data(),
412                             &LDTR,
413                             B.data(),
414                             &LDB,
415                             &_SCALE,
416                             X.data(),
417                             &LDX,
418                             &_XNORM,
419                             &INFO);
420     } else {
421         ASSERT(0);
422     }
423     ASSERT(INFO>=0);
424 
425     scale = _SCALE;
426     xNorm = _XNORM;
427 
428     return INFO;
429 }
430 
431 #endif // CHECK_CXXLAPACK
432 
433 //== public interface ==========================================================
434 
435 template <typename SIGN, typename MTL, typename MTR, typename MB,
436           typename SCALE, typename MX, typename XNORM>
437 typename GeMatrix<MX>::IndexType
438 lasy2(bool                  transLeft,
439       bool                  transRight,
440       SIGN                  sign,
441       const GeMatrix<MTL>   &TL,
442       const GeMatrix<MTR>   &TR,
443       const GeMatrix<MB>    &B,
444       SCALE                 &scale,
445       GeMatrix<MX>          &X,
446       XNORM                 &xNorm)
447 {
448     LAPACK_DEBUG_OUT("lasy2");
449 
450     typedef typename GeMatrix<MX>::IndexType IndexType;
451 //
452 //  Test the input parameters
453 //
454 #   ifndef NDEBUG
455     ASSERT(sign==1 || sign==-1);
456 
457     ASSERT(TL.firstRow()==1);
458     ASSERT(TL.firstCol()==1);
459     ASSERT(TL.numRows()==TL.numCols());
460     ASSERT(TL.numRows()<=2);
461 
462     ASSERT(TR.firstRow()==1);
463     ASSERT(TR.firstCol()==1);
464     ASSERT(TR.numRows()==TR.numCols());
465     ASSERT(TR.numRows()<=2);
466 
467     ASSERT(B.firstRow()==1);
468     ASSERT(B.firstCol()==1);
469     ASSERT(B.numRows()==TL.numRows());
470     ASSERT(B.numCols()==TR.numRows());
471 
472     ASSERT(X.firstRow()==1);
473     ASSERT(X.firstCol()==1);
474     ASSERT(X.numRows()==TL.numRows());
475     ASSERT(X.numCols()==TR.numRows());
476 #   endif
477 
478 #   ifdef CHECK_CXXLAPACK
479 //
480 //  Make copies of output arguments
481 //
482     SCALE                           scale_org  = scale;
483     typename GeMatrix<MX>::NoView   X_org      = X;
484     XNORM                           xNorm_org  = xNorm;
485 #   endif
486 
487 //
488 //  Call implementation
489 //
490     IndexType info = lasy2_generic(transLeft, transRight, sign,
491                                    TL, TR, B,
492                                    scale, X, xNorm);
493 #   ifdef CHECK_CXXLAPACK
494 //
495 //  Make copies of results computed by the generic implementation
496 //
497     SCALE                           scale_generic   = scale;
498     typename GeMatrix<MX>::NoView   X_generic       = X;
499     XNORM                           xNorm_generic   = xNorm;
500 
501 //
502 //  restore output arguments
503 //
504     scale   = scale_org;
505     X       = X_org;
506     xNorm   = xNorm_org;
507 
508 //
509 //  Compare generic results with results from the native implementation
510 //
511 
512     IndexType _info = lasy2_native(transLeft, transRight, sign,
513                                    TL, TR, B,
514                                    scale, X, xNorm);
515 
516     bool failed = false;
517     if (! isIdentical(scale_generic, scale, "scale_generic""scale")) {
518         std::cerr << "CXXLAPACK: scale_generic = "
519                   << scale_generic << std::endl;
520         std::cerr << "F77LAPACK: scale = " << scale << std::endl;
521         failed = true;
522     }
523     if (! isIdentical(X_generic, X, "X_generic""X")) {
524         std::cerr << "CXXLAPACK: X_generic = "
525                   << X_generic << std::endl;
526         std::cerr << "F77LAPACK: X = " << X << std::endl;
527         failed = true;
528     }
529     if (! isIdentical(xNorm_generic, xNorm, "xNorm_generic""xNorm")) {
530         std::cerr << "CXXLAPACK: xNorm_generic = "
531                   << xNorm_generic << std::endl;
532         std::cerr << "F77LAPACK: xNorm = " << xNorm << std::endl;
533         failed = true;
534     }
535     if (! isIdentical(info, _info, " info""_info")) {
536         std::cerr << "CXXLAPACK:  info = " << info << std::endl;
537         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
538         failed = true;
539     }
540 
541     if (failed) {
542         ASSERT(0);
543     }
544 #   endif
545 
546     return info;
547 }
548 
549 //-- forwarding ----------------------------------------------------------------
550 template <typename SIGN, typename MTL, typename MTR, typename MB,
551           typename SCALE, typename MX, typename XNORM>
552 typename MX::IndexType
553 lasy2(bool                  transLeft,
554       bool                  transRight,
555       SIGN                  sign,
556       const MTL             &TL,
557       const MTR             &TR,
558       const MB              &B,
559       SCALE                 &&scale,
560       MX                    &&X,
561       XNORM                 &&xNorm)
562 {
563     typedef typename MX::IndexType IndexType;
564 
565     CHECKPOINT_ENTER;
566     const IndexType info = lasy2(transLeft, transRight, sign,
567                                  TL, TR, B,
568                                  scale, X, xNorm);
569     CHECKPOINT_LEAVE;
570 
571     return info;
572 }
573 
574 } } // namespace lapack, flens
575 
576 #endif // FLENS_LAPACK_EIG_LASY2_TCC