1 /*
2 * Copyright (c) 2011, Michael Lehn
3 *
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1) Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 * 2) Redistributions in binary form must reproduce the above copyright
13 * notice, this list of conditions and the following disclaimer in
14 * the documentation and/or other materials provided with the
15 * distribution.
16 * 3) Neither the name of the FLENS development group nor the names of
17 * its contributors may be used to endorse or promote products derived
18 * from this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 /* Based on
34 *
35 SUBROUTINE DLASY2( LTRANL, LTRANR, ISGN, N1, N2, TL, LDTL, TR,
36 $ LDTR, B, LDB, SCALE, X, LDX, XNORM, INFO )
37 *
38 * -- LAPACK auxiliary routine (version 3.2) --
39 * -- LAPACK is a software package provided by Univ. of Tennessee, --
40 * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
41 * November 2006
42 */
43
44 #ifndef FLENS_LAPACK_EIG_LASY2_TCC
45 #define FLENS_LAPACK_EIG_LASY2_TCC 1
46
47 #include <flens/blas/blas.h>
48 #include <flens/lapack/lapack.h>
49
50 namespace flens { namespace lapack {
51
52 //== generic lapack implementation =============================================
53
54 template <typename SIGN, typename MTL, typename MTR, typename MB,
55 typename SCALE, typename MX, typename XNORM>
56 typename GeMatrix<MX>::IndexType
57 lasy2_generic(bool transLeft,
58 bool transRight,
59 SIGN iSign,
60 const GeMatrix<MTL> &TL,
61 const GeMatrix<MTR> &TR,
62 const GeMatrix<MB> &B,
63 SCALE &scale,
64 GeMatrix<MX> &X,
65 XNORM &xNorm)
66 {
67 using std::abs;
68 using flens::max;
69 using std::swap;
70
71 typedef typename GeMatrix<MX>::ElementType ElementType;
72 typedef typename GeMatrix<MX>::IndexType IndexType;
73
74 const ElementType Zero(0), Half(0.5), One(1), Two(2), Eight(8);
75
76 const IndexType n1 = TL.numRows();
77 const IndexType n2 = TR.numRows();
78
79 const Underscore<IndexType> _;
80
81 IndexType info = 0;
82 IndexType iPiv;
83
84 bool bSwap, xSwap;
85 ElementType safeMin, beta, gamma, tau1, U11, U12, U22, L21, temp;
86
87 //
88 // .. Local Arrays ..
89 //
90 IndexType _jPivData[4],
91 _locU12Data[4] = { 3, 4, 1, 2},
92 _locL21Data[4] = { 2, 1, 4, 3},
93 _locU22Data[4] = { 4, 3, 2, 1};
94 DenseVectorView<IndexType>
95 jPiv = typename DenseVectorView<IndexType>::Engine(4, _jPivData),
96 locU12 = typename DenseVectorView<IndexType>::Engine(4, _locU12Data),
97 locL21 = typename DenseVectorView<IndexType>::Engine(4, _locL21Data),
98 locU22 = typename DenseVectorView<IndexType>::Engine(4, _locU22Data);
99
100 ElementType _bTmpData[4], _tmpData[4], _x2Data[2];
101 DenseVectorView<ElementType>
102 bTmp = typename DenseVectorView<ElementType>::Engine(4, _bTmpData),
103 tmp = typename DenseVectorView<ElementType>::Engine(4, _tmpData),
104 x2 = typename DenseVectorView<ElementType>::Engine(2, _x2Data);
105
106 bool _xSwapPivData[4] = {false, false, true, true},
107 _bSwapPivData[4] = {false, true, false, true};
108 DenseVectorView<bool>
109 xSwapPiv = typename DenseVectorView<bool>::Engine(4, _xSwapPivData),
110 bSwapPiv = typename DenseVectorView<bool>::Engine(4, _bSwapPivData);
111
112 ElementType _t16Data[16];
113 GeMatrixView<ElementType>
114 T16 = typename GeMatrixView<ElementType>::Engine(4, 4, _t16Data, 4);
115
116 //
117 // Quick return if possible
118 //
119 if (n1==0 || n2==0) {
120 return info;
121 }
122 //
123 // Set constants to control overflow
124 //
125 const ElementType eps = lamch<ElementType>(Precision);
126 const ElementType smallNum = lamch<ElementType>(SafeMin) / eps;
127 const ElementType sign = iSign;
128
129 const IndexType k = n1 + n1 + n2 - 2;
130
131 switch (k) {
132 //
133 // 1 by 1: TL11*X + SGN*X*TR11 = B11
134 //
135 case 1:
136 tau1 = TL(1,1) + sign*TR(1,1);
137 beta = abs(tau1);
138 if (beta<=smallNum) {
139 tau1 = smallNum;
140 beta = smallNum;
141 info = 1;
142 }
143
144 scale = One;
145 gamma = abs(B(1,1));
146 if (smallNum*gamma>beta) {
147 scale = One/gamma;
148 }
149
150 X(1,1) = (B(1,1)*scale) / tau1;
151 xNorm = abs(X(1,1));
152 return info;
153
154 case 2:
155 case 3:
156 if (k==2) {
157 //
158 // 1 by 2:
159 // TL11*[X11 X12] + ISGN*[X11 X12]*op[TR11 TR12] = [B11 B12]
160 // [TR21 TR22]
161 //
162 safeMin = max(eps*max(abs(TL(1,1)), abs(TR(1,1)),
163 abs(TR(1,2)), abs(TR(2,1)),
164 abs(TR(2,2))),
165 smallNum);
166 tmp(1) = TL(1,1) + sign*TR(1,1);
167 tmp(4) = TL(1,1) + sign*TR(2,2);
168 if (transRight) {
169 tmp(2) = sign*TR(2,1);
170 tmp(3) = sign*TR(1,2);
171 } else {
172 tmp(2) = sign*TR(1,2);
173 tmp(3) = sign*TR(2,1);
174 }
175 bTmp(1) = B(1,1);
176 bTmp(2) = B(1,2);
177 } else {
178 //
179 // 2 by 1:
180 // op[TL11 TL12]*[X11] + ISGN* [X11]*TR11 = [B11]
181 // [TL21 TL22] [X21] [X21] [B21]
182 //
183 safeMin = max(eps*max(abs(TR(1,1)), abs(TL(1,1)),
184 abs(TL(1,2)), abs(TL(2,1)),
185 abs(TL(2,2))),
186 smallNum);
187 tmp(1) = TL(1,1) + sign*TR(1,1);
188 tmp(4) = TL(2,2) + sign*TR(1,1);
189 if (transLeft) {
190 tmp(2) = TL(1,2);
191 tmp(3) = TL(2,1);
192 } else {
193 tmp(2) = TL(2,1);
194 tmp(3) = TL(1,2);
195 }
196 bTmp(1) = B(1,1);
197 bTmp(2) = B(2,1);
198 }
199 //
200 // Solve 2 by 2 system using complete pivoting.
201 // Set pivots less than SMIN to SMIN.
202 //
203 iPiv = blas::iamax(tmp);
204 U11 = tmp(iPiv);
205 if (abs(U11)<=safeMin) {
206 info = 1;
207 U11 = safeMin;
208 }
209 U12 = tmp(locU12(iPiv));
210 L21 = tmp(locL21(iPiv)) / U11;
211 U22 = tmp(locU22(iPiv)) - U12*L21;
212 xSwap = xSwapPiv(iPiv);
213 bSwap = bSwapPiv(iPiv);
214 if (abs(U22)<=safeMin) {
215 info = 1;
216 U22 = safeMin;
217 }
218 if (bSwap) {
219 temp = bTmp(2);
220 bTmp(2) = bTmp(1) - L21*temp;
221 bTmp(1) = temp;
222 } else {
223 bTmp(2) -= L21*bTmp(1);
224 }
225 scale = One;
226 if ((Two*smallNum)*abs(bTmp(2))>abs(U22)
227 || (Two*smallNum)*abs(bTmp(1))>abs(U11))
228 {
229 scale = Half / max(abs(bTmp(1)), abs(bTmp(2)));
230 bTmp(1) *= scale;
231 bTmp(2) *= scale;
232 }
233 x2(2) = bTmp(2)/U22;
234 x2(1) = bTmp(1)/U11 - (U12/U11)*x2(2);
235 if (xSwap) {
236 swap(x2(1), x2(2));
237 }
238 X(1,1) = x2(1);
239 if (n1==1) {
240 X(1,2) = x2(2);
241 xNorm = abs(X(1,1)) + abs(X(1,2));
242 } else {
243 X(2,1) = x2(2);
244 xNorm = max(abs(X(1,1)), abs(X(2,1)));
245 }
246 return info;
247
248 //
249 // 2 by 2:
250 // op[TL11 TL12]*[X11 X12] +ISGN* [X11 X12]*op[TR11 TR12] = [B11 B12]
251 // [TL21 TL22] [X21 X22] [X21 X22] [TR21 TR22] [B21 B22]
252 //
253 // Solve equivalent 4 by 4 system using complete pivoting.
254 // Set pivots less than SMIN to SMIN.
255 //
256 case 4:
257 safeMin = max(abs(TR(1,1)), abs(TR(1,2)), abs(TR(2,1)), abs(TR(2,2)));
258 safeMin = max(safeMin, abs(TL(1,1)), abs(TL(1,2)),
259 abs(TL(2,1)), abs(TL(2,2)));
260 safeMin = max(eps*safeMin, smallNum);
261 bTmp(1) = Zero;
262 T16 = 0;
263 T16(1,1) = TL(1,1) + sign*TR(1,1);
264 T16(2,2) = TL(2,2) + sign*TR(1,1);
265 T16(3,3) = TL(1,1) + sign*TR(2,2);
266 T16(4,4) = TL(2,2) + sign*TR(2,2);
267 if (transLeft) {
268 T16(1,2) = TL(2,1);
269 T16(2,1) = TL(1,2);
270 T16(3,4) = TL(2,1);
271 T16(4,3) = TL(1,2);
272 } else {
273 T16(1,2) = TL(1,2);
274 T16(2,1) = TL(2,1);
275 T16(3,4) = TL(1,2);
276 T16(4,3) = TL(2,1);
277 }
278 if (transRight) {
279 T16(1,3) = sign*TR(1,2);
280 T16(2,4) = sign*TR(1,2);
281 T16(3,1) = sign*TR(2,1);
282 T16(4,2) = sign*TR(2,1);
283 } else {
284 T16(1,3) = sign*TR(2,1);
285 T16(2,4) = sign*TR(2,1);
286 T16(3,1) = sign*TR(1,2);
287 T16(4,2) = sign*TR(1,2);
288 }
289 bTmp(1) = B(1,1);
290 bTmp(2) = B(2,1);
291 bTmp(3) = B(1,2);
292 bTmp(4) = B(2,2);
293 //
294 // Perform elimination
295 //
296 for (IndexType i=1; i<=3; ++i) {
297 ElementType xMax = Zero;
298 IndexType ipSv = -1, jpSv = -1;
299
300 for (IndexType ip=i; ip<=4; ++ip) {
301 for (IndexType jp=i; jp<=4; ++jp) {
302 if (abs(T16(ip,jp))>=xMax) {
303 xMax = abs(T16(ip,jp));
304 ipSv = ip;
305 jpSv = jp;
306 }
307 }
308 }
309 if (ipSv!=i) {
310 blas::swap(T16(ipSv,_), T16(i,_));
311 swap(bTmp(i), bTmp(ipSv));
312 }
313 if (jpSv!=i) {
314 blas::swap(T16(_,jpSv), T16(_,i));
315 }
316 jPiv(i) = jpSv;
317 if (abs(T16(i,i))<safeMin) {
318 info = 1;
319 T16(i,i) = safeMin;
320 }
321 for (IndexType j=i+1; j<=4; ++j) {
322 T16(j,i) /= T16(i,i);
323 bTmp(j) -= T16(j,i)*bTmp(i);
324 for (IndexType k=i+1; k<=4; ++k) {
325 T16(j,k) -= T16(j,i)*T16(i,k);
326 }
327 }
328 }
329 if (abs(T16(4,4))<safeMin) {
330 T16(4,4) = safeMin;
331 }
332 scale = One;
333 if ((Eight*smallNum)*abs(bTmp(1))>abs(T16(1,1))
334 || (Eight*smallNum)*abs(bTmp(2))>abs(T16(2,2))
335 || (Eight*smallNum)*abs(bTmp(3))>abs(T16(3,3))
336 || (Eight*smallNum)*abs(bTmp(4))>abs(T16(4,4)))
337 {
338 scale = (One/Eight) / max(abs(bTmp(1)), abs(bTmp(2)),
339 abs(bTmp(3)), abs(bTmp(4)));
340 bTmp(1) *= scale;
341 bTmp(2) *= scale;
342 bTmp(3) *= scale;
343 bTmp(4) *= scale;
344 }
345 for (IndexType i=1; i<=4; ++i) {
346 IndexType k = 5 - i;
347 const ElementType temp = One/T16(k,k);
348 tmp(k) = bTmp(k)*temp;
349 for (IndexType j=k+1; j<=4; ++j) {
350 tmp(k) -= (temp*T16(k,j))*tmp(j);
351 }
352 }
353 for (IndexType i=1; i<=3; ++i) {
354 if (jPiv(4-i)!=4-i) {
355 swap(tmp(4-i), tmp(jPiv(4-i)));
356 }
357 }
358 X(1,1) = tmp(1);
359 X(2,1) = tmp(2);
360 X(1,2) = tmp(3);
361 X(2,2) = tmp(4);
362 xNorm = max(abs(tmp(1))+abs(tmp(3)), abs(tmp(2))+abs(tmp(4)));
363 return info;
364 }
365
366 // error if switch does not handle all cases
367 ASSERT(0);
368 return info;
369 }
370
371 //== interface for native lapack ===============================================
372
373 #ifdef CHECK_CXXLAPACK
374
375 template <typename SIGN, typename MTL, typename MTR, typename MB,
376 typename SCALE, typename MX, typename XNORM>
377 typename GeMatrix<MX>::IndexType
378 lasy2_native(bool transLeft,
379 bool transRight,
380 SIGN sign,
381 const GeMatrix<MTL> &TL,
382 const GeMatrix<MTR> &TR,
383 const GeMatrix<MB> &B,
384 SCALE &scale,
385 GeMatrix<MX> &X,
386 XNORM &xNorm)
387 {
388 typedef typename GeMatrix<MX>::ElementType ElementType;
389
390 const LOGICAL LTRANL = transLeft;
391 const LOGICAL LTRANR = transRight;
392 const INTEGER ISGN = sign;
393 const INTEGER N1 = TL.numRows();
394 const INTEGER N2 = TR.numRows();
395 const INTEGER LDTL = TL.leadingDimension();
396 const INTEGER LDTR = TR.leadingDimension();
397 const INTEGER LDB = B.leadingDimension();
398 ElementType _SCALE = scale;
399 const INTEGER LDX = X.leadingDimension();
400 ElementType _XNORM = xNorm;
401 INTEGER INFO;
402
403 if (IsSame<ElementType,DOUBLE>::value) {
404 LAPACK_IMPL(dlasy2)(<RANL,
405 <RANR,
406 &ISGN,
407 &N1,
408 &N2,
409 TL.data(),
410 &LDTL,
411 TR.data(),
412 &LDTR,
413 B.data(),
414 &LDB,
415 &_SCALE,
416 X.data(),
417 &LDX,
418 &_XNORM,
419 &INFO);
420 } else {
421 ASSERT(0);
422 }
423 ASSERT(INFO>=0);
424
425 scale = _SCALE;
426 xNorm = _XNORM;
427
428 return INFO;
429 }
430
431 #endif // CHECK_CXXLAPACK
432
433 //== public interface ==========================================================
434
435 template <typename SIGN, typename MTL, typename MTR, typename MB,
436 typename SCALE, typename MX, typename XNORM>
437 typename GeMatrix<MX>::IndexType
438 lasy2(bool transLeft,
439 bool transRight,
440 SIGN sign,
441 const GeMatrix<MTL> &TL,
442 const GeMatrix<MTR> &TR,
443 const GeMatrix<MB> &B,
444 SCALE &scale,
445 GeMatrix<MX> &X,
446 XNORM &xNorm)
447 {
448 LAPACK_DEBUG_OUT("lasy2");
449
450 typedef typename GeMatrix<MX>::IndexType IndexType;
451 //
452 // Test the input parameters
453 //
454 # ifndef NDEBUG
455 ASSERT(sign==1 || sign==-1);
456
457 ASSERT(TL.firstRow()==1);
458 ASSERT(TL.firstCol()==1);
459 ASSERT(TL.numRows()==TL.numCols());
460 ASSERT(TL.numRows()<=2);
461
462 ASSERT(TR.firstRow()==1);
463 ASSERT(TR.firstCol()==1);
464 ASSERT(TR.numRows()==TR.numCols());
465 ASSERT(TR.numRows()<=2);
466
467 ASSERT(B.firstRow()==1);
468 ASSERT(B.firstCol()==1);
469 ASSERT(B.numRows()==TL.numRows());
470 ASSERT(B.numCols()==TR.numRows());
471
472 ASSERT(X.firstRow()==1);
473 ASSERT(X.firstCol()==1);
474 ASSERT(X.numRows()==TL.numRows());
475 ASSERT(X.numCols()==TR.numRows());
476 # endif
477
478 # ifdef CHECK_CXXLAPACK
479 //
480 // Make copies of output arguments
481 //
482 SCALE scale_org = scale;
483 typename GeMatrix<MX>::NoView X_org = X;
484 XNORM xNorm_org = xNorm;
485 # endif
486
487 //
488 // Call implementation
489 //
490 IndexType info = lasy2_generic(transLeft, transRight, sign,
491 TL, TR, B,
492 scale, X, xNorm);
493 # ifdef CHECK_CXXLAPACK
494 //
495 // Make copies of results computed by the generic implementation
496 //
497 SCALE scale_generic = scale;
498 typename GeMatrix<MX>::NoView X_generic = X;
499 XNORM xNorm_generic = xNorm;
500
501 //
502 // restore output arguments
503 //
504 scale = scale_org;
505 X = X_org;
506 xNorm = xNorm_org;
507
508 //
509 // Compare generic results with results from the native implementation
510 //
511
512 IndexType _info = lasy2_native(transLeft, transRight, sign,
513 TL, TR, B,
514 scale, X, xNorm);
515
516 bool failed = false;
517 if (! isIdentical(scale_generic, scale, "scale_generic", "scale")) {
518 std::cerr << "CXXLAPACK: scale_generic = "
519 << scale_generic << std::endl;
520 std::cerr << "F77LAPACK: scale = " << scale << std::endl;
521 failed = true;
522 }
523 if (! isIdentical(X_generic, X, "X_generic", "X")) {
524 std::cerr << "CXXLAPACK: X_generic = "
525 << X_generic << std::endl;
526 std::cerr << "F77LAPACK: X = " << X << std::endl;
527 failed = true;
528 }
529 if (! isIdentical(xNorm_generic, xNorm, "xNorm_generic", "xNorm")) {
530 std::cerr << "CXXLAPACK: xNorm_generic = "
531 << xNorm_generic << std::endl;
532 std::cerr << "F77LAPACK: xNorm = " << xNorm << std::endl;
533 failed = true;
534 }
535 if (! isIdentical(info, _info, " info", "_info")) {
536 std::cerr << "CXXLAPACK: info = " << info << std::endl;
537 std::cerr << "F77LAPACK: _info = " << _info << std::endl;
538 failed = true;
539 }
540
541 if (failed) {
542 ASSERT(0);
543 }
544 # endif
545
546 return info;
547 }
548
549 //-- forwarding ----------------------------------------------------------------
550 template <typename SIGN, typename MTL, typename MTR, typename MB,
551 typename SCALE, typename MX, typename XNORM>
552 typename MX::IndexType
553 lasy2(bool transLeft,
554 bool transRight,
555 SIGN sign,
556 const MTL &TL,
557 const MTR &TR,
558 const MB &B,
559 SCALE &&scale,
560 MX &&X,
561 XNORM &&xNorm)
562 {
563 typedef typename MX::IndexType IndexType;
564
565 CHECKPOINT_ENTER;
566 const IndexType info = lasy2(transLeft, transRight, sign,
567 TL, TR, B,
568 scale, X, xNorm);
569 CHECKPOINT_LEAVE;
570
571 return info;
572 }
573
574 } } // namespace lapack, flens
575
576 #endif // FLENS_LAPACK_EIG_LASY2_TCC
/*
 * Copyright (c) 2011, Michael Lehn
3 *
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1) Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 * 2) Redistributions in binary form must reproduce the above copyright
13 * notice, this list of conditions and the following disclaimer in
14 * the documentation and/or other materials provided with the
15 * distribution.
16 * 3) Neither the name of the FLENS development group nor the names of
17 * its contributors may be used to endorse or promote products derived
18 * from this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 /* Based on
34 *
35 SUBROUTINE DLASY2( LTRANL, LTRANR, ISGN, N1, N2, TL, LDTL, TR,
36 $ LDTR, B, LDB, SCALE, X, LDX, XNORM, INFO )
37 *
38 * -- LAPACK auxiliary routine (version 3.2) --
39 * -- LAPACK is a software package provided by Univ. of Tennessee, --
40 * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
41 * November 2006
42 */
43
44 #ifndef FLENS_LAPACK_EIG_LASY2_TCC
45 #define FLENS_LAPACK_EIG_LASY2_TCC 1
46
47 #include <flens/blas/blas.h>
48 #include <flens/lapack/lapack.h>
49
50 namespace flens { namespace lapack {
51
52 //== generic lapack implementation =============================================
53
54 template <typename SIGN, typename MTL, typename MTR, typename MB,
55 typename SCALE, typename MX, typename XNORM>
56 typename GeMatrix<MX>::IndexType
57 lasy2_generic(bool transLeft,
58 bool transRight,
59 SIGN iSign,
60 const GeMatrix<MTL> &TL,
61 const GeMatrix<MTR> &TR,
62 const GeMatrix<MB> &B,
63 SCALE &scale,
64 GeMatrix<MX> &X,
65 XNORM &xNorm)
66 {
67 using std::abs;
68 using flens::max;
69 using std::swap;
70
71 typedef typename GeMatrix<MX>::ElementType ElementType;
72 typedef typename GeMatrix<MX>::IndexType IndexType;
73
74 const ElementType Zero(0), Half(0.5), One(1), Two(2), Eight(8);
75
76 const IndexType n1 = TL.numRows();
77 const IndexType n2 = TR.numRows();
78
79 const Underscore<IndexType> _;
80
81 IndexType info = 0;
82 IndexType iPiv;
83
84 bool bSwap, xSwap;
85 ElementType safeMin, beta, gamma, tau1, U11, U12, U22, L21, temp;
86
87 //
88 // .. Local Arrays ..
89 //
90 IndexType _jPivData[4],
91 _locU12Data[4] = { 3, 4, 1, 2},
92 _locL21Data[4] = { 2, 1, 4, 3},
93 _locU22Data[4] = { 4, 3, 2, 1};
94 DenseVectorView<IndexType>
95 jPiv = typename DenseVectorView<IndexType>::Engine(4, _jPivData),
96 locU12 = typename DenseVectorView<IndexType>::Engine(4, _locU12Data),
97 locL21 = typename DenseVectorView<IndexType>::Engine(4, _locL21Data),
98 locU22 = typename DenseVectorView<IndexType>::Engine(4, _locU22Data);
99
100 ElementType _bTmpData[4], _tmpData[4], _x2Data[2];
101 DenseVectorView<ElementType>
102 bTmp = typename DenseVectorView<ElementType>::Engine(4, _bTmpData),
103 tmp = typename DenseVectorView<ElementType>::Engine(4, _tmpData),
104 x2 = typename DenseVectorView<ElementType>::Engine(2, _x2Data);
105
106 bool _xSwapPivData[4] = {false, false, true, true},
107 _bSwapPivData[4] = {false, true, false, true};
108 DenseVectorView<bool>
109 xSwapPiv = typename DenseVectorView<bool>::Engine(4, _xSwapPivData),
110 bSwapPiv = typename DenseVectorView<bool>::Engine(4, _bSwapPivData);
111
112 ElementType _t16Data[16];
113 GeMatrixView<ElementType>
114 T16 = typename GeMatrixView<ElementType>::Engine(4, 4, _t16Data, 4);
115
116 //
117 // Quick return if possible
118 //
119 if (n1==0 || n2==0) {
120 return info;
121 }
122 //
123 // Set constants to control overflow
124 //
125 const ElementType eps = lamch<ElementType>(Precision);
126 const ElementType smallNum = lamch<ElementType>(SafeMin) / eps;
127 const ElementType sign = iSign;
128
129 const IndexType k = n1 + n1 + n2 - 2;
130
131 switch (k) {
132 //
133 // 1 by 1: TL11*X + SGN*X*TR11 = B11
134 //
135 case 1:
136 tau1 = TL(1,1) + sign*TR(1,1);
137 beta = abs(tau1);
138 if (beta<=smallNum) {
139 tau1 = smallNum;
140 beta = smallNum;
141 info = 1;
142 }
143
144 scale = One;
145 gamma = abs(B(1,1));
146 if (smallNum*gamma>beta) {
147 scale = One/gamma;
148 }
149
150 X(1,1) = (B(1,1)*scale) / tau1;
151 xNorm = abs(X(1,1));
152 return info;
153
154 case 2:
155 case 3:
156 if (k==2) {
157 //
158 // 1 by 2:
159 // TL11*[X11 X12] + ISGN*[X11 X12]*op[TR11 TR12] = [B11 B12]
160 // [TR21 TR22]
161 //
162 safeMin = max(eps*max(abs(TL(1,1)), abs(TR(1,1)),
163 abs(TR(1,2)), abs(TR(2,1)),
164 abs(TR(2,2))),
165 smallNum);
166 tmp(1) = TL(1,1) + sign*TR(1,1);
167 tmp(4) = TL(1,1) + sign*TR(2,2);
168 if (transRight) {
169 tmp(2) = sign*TR(2,1);
170 tmp(3) = sign*TR(1,2);
171 } else {
172 tmp(2) = sign*TR(1,2);
173 tmp(3) = sign*TR(2,1);
174 }
175 bTmp(1) = B(1,1);
176 bTmp(2) = B(1,2);
177 } else {
178 //
179 // 2 by 1:
180 // op[TL11 TL12]*[X11] + ISGN* [X11]*TR11 = [B11]
181 // [TL21 TL22] [X21] [X21] [B21]
182 //
183 safeMin = max(eps*max(abs(TR(1,1)), abs(TL(1,1)),
184 abs(TL(1,2)), abs(TL(2,1)),
185 abs(TL(2,2))),
186 smallNum);
187 tmp(1) = TL(1,1) + sign*TR(1,1);
188 tmp(4) = TL(2,2) + sign*TR(1,1);
189 if (transLeft) {
190 tmp(2) = TL(1,2);
191 tmp(3) = TL(2,1);
192 } else {
193 tmp(2) = TL(2,1);
194 tmp(3) = TL(1,2);
195 }
196 bTmp(1) = B(1,1);
197 bTmp(2) = B(2,1);
198 }
199 //
200 // Solve 2 by 2 system using complete pivoting.
201 // Set pivots less than SMIN to SMIN.
202 //
203 iPiv = blas::iamax(tmp);
204 U11 = tmp(iPiv);
205 if (abs(U11)<=safeMin) {
206 info = 1;
207 U11 = safeMin;
208 }
209 U12 = tmp(locU12(iPiv));
210 L21 = tmp(locL21(iPiv)) / U11;
211 U22 = tmp(locU22(iPiv)) - U12*L21;
212 xSwap = xSwapPiv(iPiv);
213 bSwap = bSwapPiv(iPiv);
214 if (abs(U22)<=safeMin) {
215 info = 1;
216 U22 = safeMin;
217 }
218 if (bSwap) {
219 temp = bTmp(2);
220 bTmp(2) = bTmp(1) - L21*temp;
221 bTmp(1) = temp;
222 } else {
223 bTmp(2) -= L21*bTmp(1);
224 }
225 scale = One;
226 if ((Two*smallNum)*abs(bTmp(2))>abs(U22)
227 || (Two*smallNum)*abs(bTmp(1))>abs(U11))
228 {
229 scale = Half / max(abs(bTmp(1)), abs(bTmp(2)));
230 bTmp(1) *= scale;
231 bTmp(2) *= scale;
232 }
233 x2(2) = bTmp(2)/U22;
234 x2(1) = bTmp(1)/U11 - (U12/U11)*x2(2);
235 if (xSwap) {
236 swap(x2(1), x2(2));
237 }
238 X(1,1) = x2(1);
239 if (n1==1) {
240 X(1,2) = x2(2);
241 xNorm = abs(X(1,1)) + abs(X(1,2));
242 } else {
243 X(2,1) = x2(2);
244 xNorm = max(abs(X(1,1)), abs(X(2,1)));
245 }
246 return info;
247
248 //
249 // 2 by 2:
250 // op[TL11 TL12]*[X11 X12] +ISGN* [X11 X12]*op[TR11 TR12] = [B11 B12]
251 // [TL21 TL22] [X21 X22] [X21 X22] [TR21 TR22] [B21 B22]
252 //
253 // Solve equivalent 4 by 4 system using complete pivoting.
254 // Set pivots less than SMIN to SMIN.
255 //
256 case 4:
257 safeMin = max(abs(TR(1,1)), abs(TR(1,2)), abs(TR(2,1)), abs(TR(2,2)));
258 safeMin = max(safeMin, abs(TL(1,1)), abs(TL(1,2)),
259 abs(TL(2,1)), abs(TL(2,2)));
260 safeMin = max(eps*safeMin, smallNum);
261 bTmp(1) = Zero;
262 T16 = 0;
263 T16(1,1) = TL(1,1) + sign*TR(1,1);
264 T16(2,2) = TL(2,2) + sign*TR(1,1);
265 T16(3,3) = TL(1,1) + sign*TR(2,2);
266 T16(4,4) = TL(2,2) + sign*TR(2,2);
267 if (transLeft) {
268 T16(1,2) = TL(2,1);
269 T16(2,1) = TL(1,2);
270 T16(3,4) = TL(2,1);
271 T16(4,3) = TL(1,2);
272 } else {
273 T16(1,2) = TL(1,2);
274 T16(2,1) = TL(2,1);
275 T16(3,4) = TL(1,2);
276 T16(4,3) = TL(2,1);
277 }
278 if (transRight) {
279 T16(1,3) = sign*TR(1,2);
280 T16(2,4) = sign*TR(1,2);
281 T16(3,1) = sign*TR(2,1);
282 T16(4,2) = sign*TR(2,1);
283 } else {
284 T16(1,3) = sign*TR(2,1);
285 T16(2,4) = sign*TR(2,1);
286 T16(3,1) = sign*TR(1,2);
287 T16(4,2) = sign*TR(1,2);
288 }
289 bTmp(1) = B(1,1);
290 bTmp(2) = B(2,1);
291 bTmp(3) = B(1,2);
292 bTmp(4) = B(2,2);
293 //
294 // Perform elimination
295 //
296 for (IndexType i=1; i<=3; ++i) {
297 ElementType xMax = Zero;
298 IndexType ipSv = -1, jpSv = -1;
299
300 for (IndexType ip=i; ip<=4; ++ip) {
301 for (IndexType jp=i; jp<=4; ++jp) {
302 if (abs(T16(ip,jp))>=xMax) {
303 xMax = abs(T16(ip,jp));
304 ipSv = ip;
305 jpSv = jp;
306 }
307 }
308 }
309 if (ipSv!=i) {
310 blas::swap(T16(ipSv,_), T16(i,_));
311 swap(bTmp(i), bTmp(ipSv));
312 }
313 if (jpSv!=i) {
314 blas::swap(T16(_,jpSv), T16(_,i));
315 }
316 jPiv(i) = jpSv;
317 if (abs(T16(i,i))<safeMin) {
318 info = 1;
319 T16(i,i) = safeMin;
320 }
321 for (IndexType j=i+1; j<=4; ++j) {
322 T16(j,i) /= T16(i,i);
323 bTmp(j) -= T16(j,i)*bTmp(i);
324 for (IndexType k=i+1; k<=4; ++k) {
325 T16(j,k) -= T16(j,i)*T16(i,k);
326 }
327 }
328 }
329 if (abs(T16(4,4))<safeMin) {
330 T16(4,4) = safeMin;
331 }
332 scale = One;
333 if ((Eight*smallNum)*abs(bTmp(1))>abs(T16(1,1))
334 || (Eight*smallNum)*abs(bTmp(2))>abs(T16(2,2))
335 || (Eight*smallNum)*abs(bTmp(3))>abs(T16(3,3))
336 || (Eight*smallNum)*abs(bTmp(4))>abs(T16(4,4)))
337 {
338 scale = (One/Eight) / max(abs(bTmp(1)), abs(bTmp(2)),
339 abs(bTmp(3)), abs(bTmp(4)));
340 bTmp(1) *= scale;
341 bTmp(2) *= scale;
342 bTmp(3) *= scale;
343 bTmp(4) *= scale;
344 }
345 for (IndexType i=1; i<=4; ++i) {
346 IndexType k = 5 - i;
347 const ElementType temp = One/T16(k,k);
348 tmp(k) = bTmp(k)*temp;
349 for (IndexType j=k+1; j<=4; ++j) {
350 tmp(k) -= (temp*T16(k,j))*tmp(j);
351 }
352 }
353 for (IndexType i=1; i<=3; ++i) {
354 if (jPiv(4-i)!=4-i) {
355 swap(tmp(4-i), tmp(jPiv(4-i)));
356 }
357 }
358 X(1,1) = tmp(1);
359 X(2,1) = tmp(2);
360 X(1,2) = tmp(3);
361 X(2,2) = tmp(4);
362 xNorm = max(abs(tmp(1))+abs(tmp(3)), abs(tmp(2))+abs(tmp(4)));
363 return info;
364 }
365
366 // error if switch does not handle all cases
367 ASSERT(0);
368 return info;
369 }
370
371 //== interface for native lapack ===============================================
372
373 #ifdef CHECK_CXXLAPACK
374
375 template <typename SIGN, typename MTL, typename MTR, typename MB,
376 typename SCALE, typename MX, typename XNORM>
377 typename GeMatrix<MX>::IndexType
378 lasy2_native(bool transLeft,
379 bool transRight,
380 SIGN sign,
381 const GeMatrix<MTL> &TL,
382 const GeMatrix<MTR> &TR,
383 const GeMatrix<MB> &B,
384 SCALE &scale,
385 GeMatrix<MX> &X,
386 XNORM &xNorm)
387 {
388 typedef typename GeMatrix<MX>::ElementType ElementType;
389
390 const LOGICAL LTRANL = transLeft;
391 const LOGICAL LTRANR = transRight;
392 const INTEGER ISGN = sign;
393 const INTEGER N1 = TL.numRows();
394 const INTEGER N2 = TR.numRows();
395 const INTEGER LDTL = TL.leadingDimension();
396 const INTEGER LDTR = TR.leadingDimension();
397 const INTEGER LDB = B.leadingDimension();
398 ElementType _SCALE = scale;
399 const INTEGER LDX = X.leadingDimension();
400 ElementType _XNORM = xNorm;
401 INTEGER INFO;
402
403 if (IsSame<ElementType,DOUBLE>::value) {
404 LAPACK_IMPL(dlasy2)(<RANL,
405 <RANR,
406 &ISGN,
407 &N1,
408 &N2,
409 TL.data(),
410 &LDTL,
411 TR.data(),
412 &LDTR,
413 B.data(),
414 &LDB,
415 &_SCALE,
416 X.data(),
417 &LDX,
418 &_XNORM,
419 &INFO);
420 } else {
421 ASSERT(0);
422 }
423 ASSERT(INFO>=0);
424
425 scale = _SCALE;
426 xNorm = _XNORM;
427
428 return INFO;
429 }
430
431 #endif // CHECK_CXXLAPACK
432
433 //== public interface ==========================================================
434
435 template <typename SIGN, typename MTL, typename MTR, typename MB,
436 typename SCALE, typename MX, typename XNORM>
437 typename GeMatrix<MX>::IndexType
438 lasy2(bool transLeft,
439 bool transRight,
440 SIGN sign,
441 const GeMatrix<MTL> &TL,
442 const GeMatrix<MTR> &TR,
443 const GeMatrix<MB> &B,
444 SCALE &scale,
445 GeMatrix<MX> &X,
446 XNORM &xNorm)
447 {
448 LAPACK_DEBUG_OUT("lasy2");
449
450 typedef typename GeMatrix<MX>::IndexType IndexType;
451 //
452 // Test the input parameters
453 //
454 # ifndef NDEBUG
455 ASSERT(sign==1 || sign==-1);
456
457 ASSERT(TL.firstRow()==1);
458 ASSERT(TL.firstCol()==1);
459 ASSERT(TL.numRows()==TL.numCols());
460 ASSERT(TL.numRows()<=2);
461
462 ASSERT(TR.firstRow()==1);
463 ASSERT(TR.firstCol()==1);
464 ASSERT(TR.numRows()==TR.numCols());
465 ASSERT(TR.numRows()<=2);
466
467 ASSERT(B.firstRow()==1);
468 ASSERT(B.firstCol()==1);
469 ASSERT(B.numRows()==TL.numRows());
470 ASSERT(B.numCols()==TR.numRows());
471
472 ASSERT(X.firstRow()==1);
473 ASSERT(X.firstCol()==1);
474 ASSERT(X.numRows()==TL.numRows());
475 ASSERT(X.numCols()==TR.numRows());
476 # endif
477
478 # ifdef CHECK_CXXLAPACK
479 //
480 // Make copies of output arguments
481 //
482 SCALE scale_org = scale;
483 typename GeMatrix<MX>::NoView X_org = X;
484 XNORM xNorm_org = xNorm;
485 # endif
486
487 //
488 // Call implementation
489 //
490 IndexType info = lasy2_generic(transLeft, transRight, sign,
491 TL, TR, B,
492 scale, X, xNorm);
493 # ifdef CHECK_CXXLAPACK
494 //
495 // Make copies of results computed by the generic implementation
496 //
497 SCALE scale_generic = scale;
498 typename GeMatrix<MX>::NoView X_generic = X;
499 XNORM xNorm_generic = xNorm;
500
501 //
502 // restore output arguments
503 //
504 scale = scale_org;
505 X = X_org;
506 xNorm = xNorm_org;
507
508 //
509 // Compare generic results with results from the native implementation
510 //
511
512 IndexType _info = lasy2_native(transLeft, transRight, sign,
513 TL, TR, B,
514 scale, X, xNorm);
515
516 bool failed = false;
517 if (! isIdentical(scale_generic, scale, "scale_generic", "scale")) {
518 std::cerr << "CXXLAPACK: scale_generic = "
519 << scale_generic << std::endl;
520 std::cerr << "F77LAPACK: scale = " << scale << std::endl;
521 failed = true;
522 }
523 if (! isIdentical(X_generic, X, "X_generic", "X")) {
524 std::cerr << "CXXLAPACK: X_generic = "
525 << X_generic << std::endl;
526 std::cerr << "F77LAPACK: X = " << X << std::endl;
527 failed = true;
528 }
529 if (! isIdentical(xNorm_generic, xNorm, "xNorm_generic", "xNorm")) {
530 std::cerr << "CXXLAPACK: xNorm_generic = "
531 << xNorm_generic << std::endl;
532 std::cerr << "F77LAPACK: xNorm = " << xNorm << std::endl;
533 failed = true;
534 }
535 if (! isIdentical(info, _info, " info", "_info")) {
536 std::cerr << "CXXLAPACK: info = " << info << std::endl;
537 std::cerr << "F77LAPACK: _info = " << _info << std::endl;
538 failed = true;
539 }
540
541 if (failed) {
542 ASSERT(0);
543 }
544 # endif
545
546 return info;
547 }
548
549 //-- forwarding ----------------------------------------------------------------
550 template <typename SIGN, typename MTL, typename MTR, typename MB,
551 typename SCALE, typename MX, typename XNORM>
552 typename MX::IndexType
553 lasy2(bool transLeft,
554 bool transRight,
555 SIGN sign,
556 const MTL &TL,
557 const MTR &TR,
558 const MB &B,
559 SCALE &&scale,
560 MX &&X,
561 XNORM &&xNorm)
562 {
563 typedef typename MX::IndexType IndexType;
564
565 CHECKPOINT_ENTER;
566 const IndexType info = lasy2(transLeft, transRight, sign,
567 TL, TR, B,
568 scale, X, xNorm);
569 CHECKPOINT_LEAVE;
570
571 return info;
572 }
573
574 } } // namespace lapack, flens
575
576 #endif // FLENS_LAPACK_EIG_LASY2_TCC