1 /*
2 * Copyright (c) 2011, Michael Lehn
3 *
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1) Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 * 2) Redistributions in binary form must reproduce the above copyright
13 * notice, this list of conditions and the following disclaimer in
14 * the documentation and/or other materials provided with the
15 * distribution.
16 * 3) Neither the name of the FLENS development group nor the names of
17 * its contributors may be used to endorse or promote products derived
18 * from this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 /*
34 SUBROUTINE DLATRS( UPLO, TRANS, DIAG, NORMIN, N, A, LDA, X, SCALE,
35 $ CNORM, INFO )
36 *
37 * -- LAPACK auxiliary routine (version 3.2) --
38 * -- LAPACK is a software package provided by Univ. of Tennessee, --
39 * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
40 * November 2006
41 */
42
43 #ifndef FLENS_LAPACK_AUX_LATRS_TCC
44 #define FLENS_LAPACK_AUX_LATRS_TCC 1
45
46 #include <flens/blas/blas.h>
47 #include <flens/lapack/lapack.h>
48
49 namespace flens { namespace lapack {
50
51 //== generic lapack implementation =============================================
52 template <typename MA, typename VX, typename SCALE, typename CNORM>
53 void
54 latrs_generic(Transpose trans,
55 bool normIn,
56 const TrMatrix<MA> &A,
57 DenseVector<VX> &x,
58 SCALE &scale,
59 DenseVector<CNORM> &cNorm)
60 {
61 using std::abs;
62 using std::max;
63 using std::min;
64
65 typedef typename DenseVector<VX>::ElementType T;
66 typedef typename DenseVector<VX>::IndexType IndexType;
67
68 const T Zero(0), Half(0.5), One(1);
69 const Underscore<IndexType> _;
70 const IndexType n = A.dim();
71 const bool upper = (A.upLo()==Upper);
72 //
73 // Quick return if possible
74 //
75 if (n==0) {
76 return;
77 }
78 //
79 // Determine machine dependent parameters to control overflow.
80 //
81 const T smallNum = lamch<T>(SafeMin) / lamch<T>(Precision);
82 const T bigNum = One / smallNum;
83 scale = One;
84
85 if (!normIn) {
86 //
87 // Compute the 1-norm of each column, not including the diagonal.
88 //
89 if (upper) {
90 //
91 // A is upper triangular.
92 //
93 for (IndexType j=1; j<=n; ++j) {
94 cNorm(j) = blas::asum(A(_(1,j-1),j));
95 }
96 } else {
97 //
98 // A is lower triangular.
99 //
100 for (IndexType j=1; j<=n-1; ++j) {
101 cNorm(j) = blas::asum(A(_(j+1,n),j));
102 }
103 cNorm(n) = Zero;
104 }
105 }
106 //
107 // Scale the column norms by TSCAL if the maximum element in CNORM is
108 // greater than BIGNUM.
109 //
110 const IndexType iMax = blas::iamax(cNorm);
111 const T tMax = cNorm(iMax);
112
113 T tScale;
114 if (tMax<=bigNum) {
115 tScale = One;
116 } else {
117 tScale = One / (smallNum*tMax);
118 cNorm *= tScale;
119 }
120 //
121 // Compute a bound on the computed solution vector to see if the
122 // Level 2 BLAS routine DTRSV can be used.
123 //
124 IndexType jFirst, jLast, jEnd, jInc;
125
126 const IndexType j = blas::iamax(x);
127 T xMax = abs(x(j));
128
129 T xBound = xMax;
130 T grow;
131
132 if (trans==NoTrans) {
133 //
134 // Compute the growth in A * x = b.
135 //
136 if (upper) {
137 jFirst = n;
138 jLast = 1;
139 jInc = -1;
140 } else {
141 jFirst = 1;
142 jLast = n;
143 jInc = 1;
144 }
145 jEnd = jLast + jInc;
146
147 if (tScale!=One) {
148 grow = Zero;
149 } else {
150 if (A.diag()==NonUnit) {
151 //
152 // A is non-unit triangular.
153 //
154 // Compute GROW = 1/G(j) and XBND = 1/M(j).
155 // Initially, G(0) = max{x(i), i=1,...,n}.
156 //
157 grow = One / max(xBound, smallNum);
158 xBound = grow;
159 bool grothFactorTooSmall = false;
160 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
161 //
162 // Exit the loop if the growth factor is too small.
163 //
164 if (grow<=smallNum) {
165 grothFactorTooSmall = true;
166 break;
167 }
168 //
169 // M(j) = G(j-1) / abs(A(j,j))
170 //
171 const T Tjj = abs(A(j,j));
172 xBound = min(xBound, min(One,Tjj)*grow);
173 if (Tjj+cNorm(j)>=smallNum) {
174 //
175 // G(j) = G(j-1)*( 1 + CNORM(j) / abs(A(j,j)) )
176 //
177 grow *= Tjj / (Tjj+cNorm(j));
178 } else {
179 //
180 // G(j) could overflow, set GROW to 0.
181 //
182 grow = Zero;
183 }
184 }
185 if (!grothFactorTooSmall) {
186 grow = xBound;
187 }
188 } else {
189 //
190 // A is unit triangular.
191 //
192 // Compute GROW = 1/G(j), where G(0) = max{x(i), i=1,...,n}.
193 //
194 grow = min(One, One/max(xBound, smallNum));
195 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
196 //
197 // Exit the loop if the growth factor is too small.
198 //
199 if (grow<=smallNum) {
200 break;
201 }
202 //
203 // G(j) = G(j-1)*( 1 + CNORM(j) )
204 //
205 grow *= One / (One+cNorm(j));
206 }
207 }
208 }
209 } else {
210 //
211 // Compute the growth in A**T * x = b.
212 //
213 if (upper) {
214 jFirst = 1;
215 jLast = n;
216 jInc = 1;
217 } else {
218 jFirst = n;
219 jLast = 1;
220 jInc = -1;
221 }
222 jEnd = jLast + jInc;
223
224 if (tScale!=One) {
225 grow = Zero;
226 } else {
227 if (A.diag()==NonUnit) {
228 //
229 // A is non-unit triangular.
230 //
231 // Compute GROW = 1/G(j) and XBND = 1/M(j).
232 // Initially, M(0) = max{x(i), i=1,...,n}.
233 //
234 grow = One / max(xBound, smallNum);
235 xBound = grow;
236
237 bool grothFactorTooSmall = false;
238 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
239 //
240 // Exit the loop if the growth factor is too small.
241 //
242 if (grow<=smallNum) {
243 grothFactorTooSmall = false;
244 break;
245 }
246 //
247 // G(j) = max( G(j-1), M(j-1)*( 1 + CNORM(j) ) )
248 //
249 const T xj = One + cNorm(j);
250 grow = min(grow, xBound / xj);
251 //
252 // M(j) = M(j-1)*( 1 + CNORM(j) ) / abs(A(j,j))
253 //
254 const T Tjj = abs(A(j,j));
255 if (xj>Tjj) {
256 xBound *= Tjj/xj;
257 }
258 }
259 if (!grothFactorTooSmall) {
260 grow = min(grow, xBound);
261 }
262 } else {
263 //
264 // A is unit triangular.
265 //
266 // Compute GROW = 1/G(j), where G(0) = max{x(i), i=1,...,n}.
267 //
268 grow = min(One, One / max(xBound, smallNum));
269 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
270 //
271 // Exit the loop if the growth factor is too small.
272 //
273 if (grow<=smallNum) {
274 break;
275 }
276 //
277 // G(j) = ( 1 + CNORM(j) )*G(j-1)
278 //
279 const T xj = One + cNorm(j);
280 grow /= xj;
281 }
282 }
283 }
284 }
285
286 if ((grow*tScale)>smallNum) {
287 //
288 // Use the Level 2 BLAS solve if the reciprocal of the bound on
289 // elements of X is not too small.
290 //
291 blas::sv(trans, A, x);
292 } else {
293 //
294 // Use a Level 1 BLAS solve, scaling intermediate results.
295 //
296 if (xMax>bigNum) {
297 //
298 // Scale X so that its components are less than or equal to
299 // BIGNUM in absolute value.
300 //
301 scale = bigNum /xMax;
302 x *= scale;
303 xMax = bigNum;
304 }
305
306 if (trans==NoTrans) {
307 //
308 // Solve A * x = b
309 //
310 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
311 //
312 // Compute x(j) = b(j) / A(j,j), scaling x if necessary.
313 //
314 T xj = abs(x(j));
315 T TjjS;
316
317 bool skip = false;
318
319 if (A.diag()==NonUnit) {
320 TjjS = A(j,j) * tScale;
321 } else {
322 TjjS = tScale;
323 if (tScale==One) {
324 skip = true;
325 }
326 }
327 if (!skip) {
328 const T Tjj = abs(TjjS);
329 if (Tjj>smallNum) {
330 //
331 // abs(A(j,j)) > SMLNUM:
332 //
333 if (Tjj<One) {
334 if (xj>Tjj*bigNum) {
335 //
336 // Scale x by 1/b(j).
337 //
338 const T rec = One / xj;
339 x *= rec;
340 scale *= rec;
341 xMax *= rec;
342 }
343 }
344 x(j) /= TjjS;
345 xj = abs(x(j));
346 } else if (Tjj>Zero) {
347 //
348 // 0 < abs(A(j,j)) <= SMLNUM:
349 //
350 if (xj>Tjj*bigNum) {
351 //
352 // Scale x by (1/abs(x(j)))*abs(A(j,j))*BIGNUM
353 // to avoid overflow when dividing by A(j,j).
354 //
355 T rec = (Tjj*bigNum) / xj;
356 if (cNorm(j)>One) {
357 //
358 // Scale by 1/CNORM(j) to avoid overflow when
359 // multiplying x(j) times column j.
360 //
361 rec /= cNorm(j);
362 }
363 x *= rec;
364 scale *= rec;
365 xMax *= rec;
366 }
367 x(j) /= TjjS;
368 xj = abs(x(j));
369 } else {
370 //
371 // A(j,j) = 0: Set x(1:n) = 0, x(j) = 1, and
372 // scale = 0, and compute a solution to A*x = 0.
373 //
374 x = Zero;
375 x(j) = One;
376 xj = One;
377 scale = Zero;
378 xMax = Zero;
379 }
380 }
381 //
382 // Scale x if necessary to avoid overflow when adding a
383 // multiple of column j of A.
384 //
385 if (xj>One) {
386 T rec = One / xj;
387 if (cNorm(j)>(bigNum-xMax)*rec) {
388 //
389 // Scale x by 1/(2*abs(x(j))).
390 //
391 rec *= Half;
392 x *= rec;
393 scale *= rec;
394 }
395 } else if (xj*cNorm(j)>(bigNum-xMax)) {
396 //
397 // Scale x by 1/2.
398 //
399 x *= Half;
400 scale *= Half;
401 }
402
403 if (upper) {
404 if (j>1) {
405 //
406 // Compute the update
407 // x(1:j-1) := x(1:j-1) - x(j) * A(1:j-1,j)
408 //
409 x(_(1,j-1)) -= (x(j)*tScale) * A(_(1,j-1),j);
410 const IndexType i = blas::iamax(x(_(1,j-1)));
411 xMax = abs(x(i));
412 }
413 } else {
414 if (j<n) {
415 //
416 // Compute the update
417 // x(j+1:n) := x(j+1:n) - x(j) * A(j+1:n,j)
418 //
419 x(_(j+1,n)) -= (x(j)*tScale) * A(_(j+1,n),j);
420 const IndexType i = j + blas::iamax(x(_(j+1,n)));
421 xMax = abs(x(i));
422 }
423 }
424 }
425
426 } else {
427 //
428 // Solve A**T * x = b
429 //
430 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
431 //
432 // Compute x(j) = b(j) - sum A(k,j)*x(k).
433 // k<>j
434 //
435 T xj = abs(x(j));
436 T uScale = tScale;
437 T rec = One / max(xMax, One);
438 T TjjS = Zero;
439
440 if (cNorm(j)>(bigNum-xj)*rec) {
441 //
442 // If x(j) could overflow, scale x by 1/(2*XMAX).
443 //
444 rec *= Half;
445 if (A.diag()==NonUnit) {
446 TjjS = A(j,j)*tScale;
447 } else {
448 TjjS = tScale;
449 }
450 const T Tjj = abs(TjjS);
451 if (Tjj>One) {
452 //
453 // Divide by A(j,j) when scaling x if A(j,j) > 1.
454 //
455 rec = min(One, rec*Tjj);
456 uScale /= TjjS;
457 }
458 if (rec<One) {
459 x *= rec;
460 scale *= rec;
461 xMax *= rec;
462 }
463 }
464
465 T sumJ = Zero;
466 if (uScale==One) {
467 //
468 // If the scaling needed for A in the dot product is 1,
469 // call DDOT to perform the dot product.
470 //
471 if (upper) {
472 sumJ = A(_(1,j-1),j) * x(_(1,j-1));
473 } else if (j<n) {
474 sumJ = A(_(j+1,n),j) * x(_(j+1,n));
475 }
476 } else {
477 //
478 // Otherwise, use in-line code for the dot product.
479 //
480 if (upper) {
481 for (IndexType i=1; i<=j-1; ++i) {
482 sumJ += (A(i,j)*uScale)*x(i);
483 }
484 } else if (j<n) {
485 for (IndexType i=j+1; i<=n; ++i) {
486 sumJ += (A(i,j)*uScale)*x(i);
487 }
488 }
489 }
490 if (uScale==tScale) {
491 //
492 // Compute x(j) := ( x(j) - sumj ) / A(j,j) if 1/A(j,j)
493 // was not used to scale the dotproduct.
494 //
495 x(j) -= sumJ;
496 xj = abs(x(j));
497
498 bool skip = false;
499 if (A.diag()==NonUnit) {
500 TjjS = A(j,j)*tScale;
501 } else {
502 TjjS = tScale;
503 if (tScale==One) {
504 skip = true;
505 }
506 }
507 if (!skip) {
508 //
509 // Compute x(j) = x(j) / A(j,j), scaling if necessary.
510 //
511 const T Tjj = abs(TjjS);
512 if (Tjj>smallNum) {
513 //
514 // abs(A(j,j)) > SMLNUM:
515 //
516 if (Tjj<One) {
517 if (xj>Tjj*bigNum) {
518 //
519 // Scale X by 1/abs(x(j)).
520 //
521 rec = One / xj;
522 x *= rec;
523 scale *= rec;
524 xMax *= rec;
525 }
526 }
527 x(j) /= TjjS;
528 } else if (Tjj>Zero) {
529 //
530 // 0 < abs(A(j,j)) <= SMLNUM:
531 //
532 if (xj>Tjj*bigNum) {
533 //
534 // Scale x by (1/abs(x(j)))*abs(A(j,j))*BIGNUM.
535 //
536 const T rec = (Tjj*bigNum) / xj;
537 x *= rec;
538 scale *= rec;
539 xMax *= rec;
540 }
541 x(j) /= TjjS;
542 } else {
543 //
544 // A(j,j) = 0: Set x(1:n) = 0, x(j) = 1, and
545 // scale = 0, and compute a solution to A**T*x = 0.
546 //
547 x = Zero;
548 x(j) = One;
549 scale = Zero;
550 xMax = Zero;
551 }
552 }
553 } else {
554 //
555 // Compute x(j) := x(j) / A(j,j) - sumj if the dot
556 // product has already been divided by 1/A(j,j).
557 //
558 x(j) = x(j)/TjjS - sumJ;
559 }
560 xMax = max(xMax, abs(x(j)));
561 }
562 }
563 scale /= tScale;
564 }
565 //
566 // Scale the column norms by 1/TSCAL for return.
567 //
568 if (tScale!=One) {
569 cNorm *= One/tScale;
570 }
571 }
572
573 //== interface for native lapack ===============================================
574
575 #ifdef CHECK_CXXLAPACK
576
577 template <typename MA, typename VX, typename SCALE, typename CNORM>
578 void
579 latrs_native(Transpose trans,
580 bool normIn,
581 const TrMatrix<MA> &A,
582 DenseVector<VX> &x,
583 SCALE &scale,
584 DenseVector<CNORM> &cNorm)
585 {
586 typedef typename TrMatrix<MA>::ElementType T;
587
588 const char UPLO = cxxblas::getF77BlasChar(A.upLo());
589 const char TRANS = cxxblas::getF77BlasChar(trans);
590 const char DIAG = cxxblas::getF77BlasChar(A.diag());
591 const char NORMIN = (normIn) ? 'Y' : 'N';
592 const INTEGER N = A.dim();
593 const INTEGER LDA = A.leadingDimension();
594 INTEGER INFO;
595
596 if (IsSame<T,double>::value) {
597 LAPACK_IMPL(dlatrs)(&UPLO,
598 &TRANS,
599 &DIAG,
600 &NORMIN,
601 &N,
602 A.data(),
603 &LDA,
604 x.data(),
605 &scale,
606 cNorm.data(),
607 &INFO);
608 } else {
609 ASSERT(0);
610 }
611 ASSERT(INFO==0);
612 }
613
614 #endif // CHECK_CXXLAPACK
615
616 //== public interface ==========================================================
617
618 template <typename MA, typename VX, typename SCALE, typename CNORM>
619 void
620 latrs(Transpose trans,
621 bool normIn,
622 const TrMatrix<MA> &A,
623 DenseVector<VX> &x,
624 SCALE &scale,
625 DenseVector<CNORM> &cNorm)
626 {
627 //
628 // Test the input parameters
629 //
630 # ifndef NDEBUG
631 ASSERT(A.firstRow()==1);
632 ASSERT(A.firstCol()==1);
633 ASSERT(x.firstIndex()==1);
634 ASSERT(x.length()==A.dim());
635 ASSERT(cNorm.firstIndex()==1);
636 ASSERT(cNorm.length()==A.dim());
637 # endif
638
639 # ifdef CHECK_CXXLAPACK
640 //
641 // Make copies of output arguments
642 //
643 typename DenseVector<VX>::NoView x_org = x;
644 SCALE scale_org = scale;
645 typename DenseVector<CNORM>::NoView cNorm_org = cNorm;
646 # endif
647
648 //
649 // Call implementation
650 //
651 latrs_generic(trans, normIn, A, x, scale, cNorm);
652 //latrs_native(trans, normIn, A, x, scale, cNorm);
653
654 # ifdef CHECK_CXXLAPACK
655 //
656 // Compare results
657 //
658 typename DenseVector<VX>::NoView x_generic = x;
659 SCALE scale_generic = scale;
660 typename DenseVector<CNORM>::NoView cNorm_generic = cNorm;
661
662 x = x_org;
663 scale = scale_org;
664 cNorm = cNorm_org;
665
666 latrs_native(trans, normIn, A, x, scale, cNorm);
667
668 bool failed = false;
669 if (! isIdentical(x_generic, x, "x_generic", "x")) {
670 std::cerr << "CXXLAPACK: x_generic = " << x_generic << std::endl;
671 std::cerr << "F77LAPACK: x = " << x << std::endl;
672 failed = true;
673 }
674
675 if (! isIdentical(scale_generic, scale, "scale_generic", "scale")) {
676 std::cerr << "CXXLAPACK: scale_generic = "
677 << scale_generic << std::endl;
678 std::cerr << "F77LAPACK: scale = " << scale << std::endl;
679 failed = true;
680 }
681
682 if (! isIdentical(cNorm_generic, cNorm, "cNorm_generic", "cNorm")) {
683 std::cerr << "CXXLAPACK: cNorm_generic = "
684 << cNorm_generic << std::endl;
685 std::cerr << "F77LAPACK: cNorm = " << cNorm << std::endl;
686 failed = true;
687 }
688
689
690 if (failed) {
691 std::cerr << "x_org = " << x_org << std::endl;
692 std::cerr << "scale_org = " << scale_org << std::endl;
693 ASSERT(0);
694 } else {
695 // std::cerr << "passed: latrs" << std::endl;
696 }
697 # endif
698 }
699
700 //-- forwarding ----------------------------------------------------------------
701 template <typename MA, typename VX, typename SCALE, typename CNORM>
702 void
703 latrs(Transpose trans,
704 bool normIn,
705 const MA &A,
706 VX &&x,
707 SCALE &&scale,
708 CNORM &&cNorm)
709 {
710 CHECKPOINT_ENTER;
711 latrs(trans, normIn, A, x, scale, cNorm);
712 CHECKPOINT_LEAVE;
713 }
714
715 } } // namespace lapack, flens
716
717 #endif // FLENS_LAPACK_AUX_LATRS_TCC
/*
 * Copyright (c) 2011, Michael Lehn
3 *
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1) Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 * 2) Redistributions in binary form must reproduce the above copyright
13 * notice, this list of conditions and the following disclaimer in
14 * the documentation and/or other materials provided with the
15 * distribution.
16 * 3) Neither the name of the FLENS development group nor the names of
17 * its contributors may be used to endorse or promote products derived
18 * from this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 /*
34 SUBROUTINE DLATRS( UPLO, TRANS, DIAG, NORMIN, N, A, LDA, X, SCALE,
35 $ CNORM, INFO )
36 *
37 * -- LAPACK auxiliary routine (version 3.2) --
38 * -- LAPACK is a software package provided by Univ. of Tennessee, --
39 * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
40 * November 2006
41 */
42
43 #ifndef FLENS_LAPACK_AUX_LATRS_TCC
44 #define FLENS_LAPACK_AUX_LATRS_TCC 1
45
46 #include <flens/blas/blas.h>
47 #include <flens/lapack/lapack.h>
48
49 namespace flens { namespace lapack {
50
51 //== generic lapack implementation =============================================
52 template <typename MA, typename VX, typename SCALE, typename CNORM>
53 void
54 latrs_generic(Transpose trans,
55 bool normIn,
56 const TrMatrix<MA> &A,
57 DenseVector<VX> &x,
58 SCALE &scale,
59 DenseVector<CNORM> &cNorm)
60 {
61 using std::abs;
62 using std::max;
63 using std::min;
64
65 typedef typename DenseVector<VX>::ElementType T;
66 typedef typename DenseVector<VX>::IndexType IndexType;
67
68 const T Zero(0), Half(0.5), One(1);
69 const Underscore<IndexType> _;
70 const IndexType n = A.dim();
71 const bool upper = (A.upLo()==Upper);
72 //
73 // Quick return if possible
74 //
75 if (n==0) {
76 return;
77 }
78 //
79 // Determine machine dependent parameters to control overflow.
80 //
81 const T smallNum = lamch<T>(SafeMin) / lamch<T>(Precision);
82 const T bigNum = One / smallNum;
83 scale = One;
84
85 if (!normIn) {
86 //
87 // Compute the 1-norm of each column, not including the diagonal.
88 //
89 if (upper) {
90 //
91 // A is upper triangular.
92 //
93 for (IndexType j=1; j<=n; ++j) {
94 cNorm(j) = blas::asum(A(_(1,j-1),j));
95 }
96 } else {
97 //
98 // A is lower triangular.
99 //
100 for (IndexType j=1; j<=n-1; ++j) {
101 cNorm(j) = blas::asum(A(_(j+1,n),j));
102 }
103 cNorm(n) = Zero;
104 }
105 }
106 //
107 // Scale the column norms by TSCAL if the maximum element in CNORM is
108 // greater than BIGNUM.
109 //
110 const IndexType iMax = blas::iamax(cNorm);
111 const T tMax = cNorm(iMax);
112
113 T tScale;
114 if (tMax<=bigNum) {
115 tScale = One;
116 } else {
117 tScale = One / (smallNum*tMax);
118 cNorm *= tScale;
119 }
120 //
121 // Compute a bound on the computed solution vector to see if the
122 // Level 2 BLAS routine DTRSV can be used.
123 //
124 IndexType jFirst, jLast, jEnd, jInc;
125
126 const IndexType j = blas::iamax(x);
127 T xMax = abs(x(j));
128
129 T xBound = xMax;
130 T grow;
131
132 if (trans==NoTrans) {
133 //
134 // Compute the growth in A * x = b.
135 //
136 if (upper) {
137 jFirst = n;
138 jLast = 1;
139 jInc = -1;
140 } else {
141 jFirst = 1;
142 jLast = n;
143 jInc = 1;
144 }
145 jEnd = jLast + jInc;
146
147 if (tScale!=One) {
148 grow = Zero;
149 } else {
150 if (A.diag()==NonUnit) {
151 //
152 // A is non-unit triangular.
153 //
154 // Compute GROW = 1/G(j) and XBND = 1/M(j).
155 // Initially, G(0) = max{x(i), i=1,...,n}.
156 //
157 grow = One / max(xBound, smallNum);
158 xBound = grow;
159 bool grothFactorTooSmall = false;
160 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
161 //
162 // Exit the loop if the growth factor is too small.
163 //
164 if (grow<=smallNum) {
165 grothFactorTooSmall = true;
166 break;
167 }
168 //
169 // M(j) = G(j-1) / abs(A(j,j))
170 //
171 const T Tjj = abs(A(j,j));
172 xBound = min(xBound, min(One,Tjj)*grow);
173 if (Tjj+cNorm(j)>=smallNum) {
174 //
175 // G(j) = G(j-1)*( 1 + CNORM(j) / abs(A(j,j)) )
176 //
177 grow *= Tjj / (Tjj+cNorm(j));
178 } else {
179 //
180 // G(j) could overflow, set GROW to 0.
181 //
182 grow = Zero;
183 }
184 }
185 if (!grothFactorTooSmall) {
186 grow = xBound;
187 }
188 } else {
189 //
190 // A is unit triangular.
191 //
192 // Compute GROW = 1/G(j), where G(0) = max{x(i), i=1,...,n}.
193 //
194 grow = min(One, One/max(xBound, smallNum));
195 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
196 //
197 // Exit the loop if the growth factor is too small.
198 //
199 if (grow<=smallNum) {
200 break;
201 }
202 //
203 // G(j) = G(j-1)*( 1 + CNORM(j) )
204 //
205 grow *= One / (One+cNorm(j));
206 }
207 }
208 }
209 } else {
210 //
211 // Compute the growth in A**T * x = b.
212 //
213 if (upper) {
214 jFirst = 1;
215 jLast = n;
216 jInc = 1;
217 } else {
218 jFirst = n;
219 jLast = 1;
220 jInc = -1;
221 }
222 jEnd = jLast + jInc;
223
224 if (tScale!=One) {
225 grow = Zero;
226 } else {
227 if (A.diag()==NonUnit) {
228 //
229 // A is non-unit triangular.
230 //
231 // Compute GROW = 1/G(j) and XBND = 1/M(j).
232 // Initially, M(0) = max{x(i), i=1,...,n}.
233 //
234 grow = One / max(xBound, smallNum);
235 xBound = grow;
236
237 bool grothFactorTooSmall = false;
238 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
239 //
240 // Exit the loop if the growth factor is too small.
241 //
242 if (grow<=smallNum) {
243 grothFactorTooSmall = false;
244 break;
245 }
246 //
247 // G(j) = max( G(j-1), M(j-1)*( 1 + CNORM(j) ) )
248 //
249 const T xj = One + cNorm(j);
250 grow = min(grow, xBound / xj);
251 //
252 // M(j) = M(j-1)*( 1 + CNORM(j) ) / abs(A(j,j))
253 //
254 const T Tjj = abs(A(j,j));
255 if (xj>Tjj) {
256 xBound *= Tjj/xj;
257 }
258 }
259 if (!grothFactorTooSmall) {
260 grow = min(grow, xBound);
261 }
262 } else {
263 //
264 // A is unit triangular.
265 //
266 // Compute GROW = 1/G(j), where G(0) = max{x(i), i=1,...,n}.
267 //
268 grow = min(One, One / max(xBound, smallNum));
269 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
270 //
271 // Exit the loop if the growth factor is too small.
272 //
273 if (grow<=smallNum) {
274 break;
275 }
276 //
277 // G(j) = ( 1 + CNORM(j) )*G(j-1)
278 //
279 const T xj = One + cNorm(j);
280 grow /= xj;
281 }
282 }
283 }
284 }
285
286 if ((grow*tScale)>smallNum) {
287 //
288 // Use the Level 2 BLAS solve if the reciprocal of the bound on
289 // elements of X is not too small.
290 //
291 blas::sv(trans, A, x);
292 } else {
293 //
294 // Use a Level 1 BLAS solve, scaling intermediate results.
295 //
296 if (xMax>bigNum) {
297 //
298 // Scale X so that its components are less than or equal to
299 // BIGNUM in absolute value.
300 //
301 scale = bigNum /xMax;
302 x *= scale;
303 xMax = bigNum;
304 }
305
306 if (trans==NoTrans) {
307 //
308 // Solve A * x = b
309 //
310 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
311 //
312 // Compute x(j) = b(j) / A(j,j), scaling x if necessary.
313 //
314 T xj = abs(x(j));
315 T TjjS;
316
317 bool skip = false;
318
319 if (A.diag()==NonUnit) {
320 TjjS = A(j,j) * tScale;
321 } else {
322 TjjS = tScale;
323 if (tScale==One) {
324 skip = true;
325 }
326 }
327 if (!skip) {
328 const T Tjj = abs(TjjS);
329 if (Tjj>smallNum) {
330 //
331 // abs(A(j,j)) > SMLNUM:
332 //
333 if (Tjj<One) {
334 if (xj>Tjj*bigNum) {
335 //
336 // Scale x by 1/b(j).
337 //
338 const T rec = One / xj;
339 x *= rec;
340 scale *= rec;
341 xMax *= rec;
342 }
343 }
344 x(j) /= TjjS;
345 xj = abs(x(j));
346 } else if (Tjj>Zero) {
347 //
348 // 0 < abs(A(j,j)) <= SMLNUM:
349 //
350 if (xj>Tjj*bigNum) {
351 //
352 // Scale x by (1/abs(x(j)))*abs(A(j,j))*BIGNUM
353 // to avoid overflow when dividing by A(j,j).
354 //
355 T rec = (Tjj*bigNum) / xj;
356 if (cNorm(j)>One) {
357 //
358 // Scale by 1/CNORM(j) to avoid overflow when
359 // multiplying x(j) times column j.
360 //
361 rec /= cNorm(j);
362 }
363 x *= rec;
364 scale *= rec;
365 xMax *= rec;
366 }
367 x(j) /= TjjS;
368 xj = abs(x(j));
369 } else {
370 //
371 // A(j,j) = 0: Set x(1:n) = 0, x(j) = 1, and
372 // scale = 0, and compute a solution to A*x = 0.
373 //
374 x = Zero;
375 x(j) = One;
376 xj = One;
377 scale = Zero;
378 xMax = Zero;
379 }
380 }
381 //
382 // Scale x if necessary to avoid overflow when adding a
383 // multiple of column j of A.
384 //
385 if (xj>One) {
386 T rec = One / xj;
387 if (cNorm(j)>(bigNum-xMax)*rec) {
388 //
389 // Scale x by 1/(2*abs(x(j))).
390 //
391 rec *= Half;
392 x *= rec;
393 scale *= rec;
394 }
395 } else if (xj*cNorm(j)>(bigNum-xMax)) {
396 //
397 // Scale x by 1/2.
398 //
399 x *= Half;
400 scale *= Half;
401 }
402
403 if (upper) {
404 if (j>1) {
405 //
406 // Compute the update
407 // x(1:j-1) := x(1:j-1) - x(j) * A(1:j-1,j)
408 //
409 x(_(1,j-1)) -= (x(j)*tScale) * A(_(1,j-1),j);
410 const IndexType i = blas::iamax(x(_(1,j-1)));
411 xMax = abs(x(i));
412 }
413 } else {
414 if (j<n) {
415 //
416 // Compute the update
417 // x(j+1:n) := x(j+1:n) - x(j) * A(j+1:n,j)
418 //
419 x(_(j+1,n)) -= (x(j)*tScale) * A(_(j+1,n),j);
420 const IndexType i = j + blas::iamax(x(_(j+1,n)));
421 xMax = abs(x(i));
422 }
423 }
424 }
425
426 } else {
427 //
428 // Solve A**T * x = b
429 //
430 for (IndexType j=jFirst; j!=jEnd; j+=jInc) {
431 //
432 // Compute x(j) = b(j) - sum A(k,j)*x(k).
433 // k<>j
434 //
435 T xj = abs(x(j));
436 T uScale = tScale;
437 T rec = One / max(xMax, One);
438 T TjjS = Zero;
439
440 if (cNorm(j)>(bigNum-xj)*rec) {
441 //
442 // If x(j) could overflow, scale x by 1/(2*XMAX).
443 //
444 rec *= Half;
445 if (A.diag()==NonUnit) {
446 TjjS = A(j,j)*tScale;
447 } else {
448 TjjS = tScale;
449 }
450 const T Tjj = abs(TjjS);
451 if (Tjj>One) {
452 //
453 // Divide by A(j,j) when scaling x if A(j,j) > 1.
454 //
455 rec = min(One, rec*Tjj);
456 uScale /= TjjS;
457 }
458 if (rec<One) {
459 x *= rec;
460 scale *= rec;
461 xMax *= rec;
462 }
463 }
464
465 T sumJ = Zero;
466 if (uScale==One) {
467 //
468 // If the scaling needed for A in the dot product is 1,
469 // call DDOT to perform the dot product.
470 //
471 if (upper) {
472 sumJ = A(_(1,j-1),j) * x(_(1,j-1));
473 } else if (j<n) {
474 sumJ = A(_(j+1,n),j) * x(_(j+1,n));
475 }
476 } else {
477 //
478 // Otherwise, use in-line code for the dot product.
479 //
480 if (upper) {
481 for (IndexType i=1; i<=j-1; ++i) {
482 sumJ += (A(i,j)*uScale)*x(i);
483 }
484 } else if (j<n) {
485 for (IndexType i=j+1; i<=n; ++i) {
486 sumJ += (A(i,j)*uScale)*x(i);
487 }
488 }
489 }
490 if (uScale==tScale) {
491 //
492 // Compute x(j) := ( x(j) - sumj ) / A(j,j) if 1/A(j,j)
493 // was not used to scale the dotproduct.
494 //
495 x(j) -= sumJ;
496 xj = abs(x(j));
497
498 bool skip = false;
499 if (A.diag()==NonUnit) {
500 TjjS = A(j,j)*tScale;
501 } else {
502 TjjS = tScale;
503 if (tScale==One) {
504 skip = true;
505 }
506 }
507 if (!skip) {
508 //
509 // Compute x(j) = x(j) / A(j,j), scaling if necessary.
510 //
511 const T Tjj = abs(TjjS);
512 if (Tjj>smallNum) {
513 //
514 // abs(A(j,j)) > SMLNUM:
515 //
516 if (Tjj<One) {
517 if (xj>Tjj*bigNum) {
518 //
519 // Scale X by 1/abs(x(j)).
520 //
521 rec = One / xj;
522 x *= rec;
523 scale *= rec;
524 xMax *= rec;
525 }
526 }
527 x(j) /= TjjS;
528 } else if (Tjj>Zero) {
529 //
530 // 0 < abs(A(j,j)) <= SMLNUM:
531 //
532 if (xj>Tjj*bigNum) {
533 //
534 // Scale x by (1/abs(x(j)))*abs(A(j,j))*BIGNUM.
535 //
536 const T rec = (Tjj*bigNum) / xj;
537 x *= rec;
538 scale *= rec;
539 xMax *= rec;
540 }
541 x(j) /= TjjS;
542 } else {
543 //
544 // A(j,j) = 0: Set x(1:n) = 0, x(j) = 1, and
545 // scale = 0, and compute a solution to A**T*x = 0.
546 //
547 x = Zero;
548 x(j) = One;
549 scale = Zero;
550 xMax = Zero;
551 }
552 }
553 } else {
554 //
555 // Compute x(j) := x(j) / A(j,j) - sumj if the dot
556 // product has already been divided by 1/A(j,j).
557 //
558 x(j) = x(j)/TjjS - sumJ;
559 }
560 xMax = max(xMax, abs(x(j)));
561 }
562 }
563 scale /= tScale;
564 }
565 //
566 // Scale the column norms by 1/TSCAL for return.
567 //
568 if (tScale!=One) {
569 cNorm *= One/tScale;
570 }
571 }
572
573 //== interface for native lapack ===============================================
574
575 #ifdef CHECK_CXXLAPACK
576
577 template <typename MA, typename VX, typename SCALE, typename CNORM>
578 void
579 latrs_native(Transpose trans,
580 bool normIn,
581 const TrMatrix<MA> &A,
582 DenseVector<VX> &x,
583 SCALE &scale,
584 DenseVector<CNORM> &cNorm)
585 {
586 typedef typename TrMatrix<MA>::ElementType T;
587
588 const char UPLO = cxxblas::getF77BlasChar(A.upLo());
589 const char TRANS = cxxblas::getF77BlasChar(trans);
590 const char DIAG = cxxblas::getF77BlasChar(A.diag());
591 const char NORMIN = (normIn) ? 'Y' : 'N';
592 const INTEGER N = A.dim();
593 const INTEGER LDA = A.leadingDimension();
594 INTEGER INFO;
595
596 if (IsSame<T,double>::value) {
597 LAPACK_IMPL(dlatrs)(&UPLO,
598 &TRANS,
599 &DIAG,
600 &NORMIN,
601 &N,
602 A.data(),
603 &LDA,
604 x.data(),
605 &scale,
606 cNorm.data(),
607 &INFO);
608 } else {
609 ASSERT(0);
610 }
611 ASSERT(INFO==0);
612 }
613
614 #endif // CHECK_CXXLAPACK
615
616 //== public interface ==========================================================
617
618 template <typename MA, typename VX, typename SCALE, typename CNORM>
619 void
620 latrs(Transpose trans,
621 bool normIn,
622 const TrMatrix<MA> &A,
623 DenseVector<VX> &x,
624 SCALE &scale,
625 DenseVector<CNORM> &cNorm)
626 {
627 //
628 // Test the input parameters
629 //
630 # ifndef NDEBUG
631 ASSERT(A.firstRow()==1);
632 ASSERT(A.firstCol()==1);
633 ASSERT(x.firstIndex()==1);
634 ASSERT(x.length()==A.dim());
635 ASSERT(cNorm.firstIndex()==1);
636 ASSERT(cNorm.length()==A.dim());
637 # endif
638
639 # ifdef CHECK_CXXLAPACK
640 //
641 // Make copies of output arguments
642 //
643 typename DenseVector<VX>::NoView x_org = x;
644 SCALE scale_org = scale;
645 typename DenseVector<CNORM>::NoView cNorm_org = cNorm;
646 # endif
647
648 //
649 // Call implementation
650 //
651 latrs_generic(trans, normIn, A, x, scale, cNorm);
652 //latrs_native(trans, normIn, A, x, scale, cNorm);
653
654 # ifdef CHECK_CXXLAPACK
655 //
656 // Compare results
657 //
658 typename DenseVector<VX>::NoView x_generic = x;
659 SCALE scale_generic = scale;
660 typename DenseVector<CNORM>::NoView cNorm_generic = cNorm;
661
662 x = x_org;
663 scale = scale_org;
664 cNorm = cNorm_org;
665
666 latrs_native(trans, normIn, A, x, scale, cNorm);
667
668 bool failed = false;
669 if (! isIdentical(x_generic, x, "x_generic", "x")) {
670 std::cerr << "CXXLAPACK: x_generic = " << x_generic << std::endl;
671 std::cerr << "F77LAPACK: x = " << x << std::endl;
672 failed = true;
673 }
674
675 if (! isIdentical(scale_generic, scale, "scale_generic", "scale")) {
676 std::cerr << "CXXLAPACK: scale_generic = "
677 << scale_generic << std::endl;
678 std::cerr << "F77LAPACK: scale = " << scale << std::endl;
679 failed = true;
680 }
681
682 if (! isIdentical(cNorm_generic, cNorm, "cNorm_generic", "cNorm")) {
683 std::cerr << "CXXLAPACK: cNorm_generic = "
684 << cNorm_generic << std::endl;
685 std::cerr << "F77LAPACK: cNorm = " << cNorm << std::endl;
686 failed = true;
687 }
688
689
690 if (failed) {
691 std::cerr << "x_org = " << x_org << std::endl;
692 std::cerr << "scale_org = " << scale_org << std::endl;
693 ASSERT(0);
694 } else {
695 // std::cerr << "passed: latrs" << std::endl;
696 }
697 # endif
698 }
699
700 //-- forwarding ----------------------------------------------------------------
701 template <typename MA, typename VX, typename SCALE, typename CNORM>
702 void
703 latrs(Transpose trans,
704 bool normIn,
705 const MA &A,
706 VX &&x,
707 SCALE &&scale,
708 CNORM &&cNorm)
709 {
710 CHECKPOINT_ENTER;
711 latrs(trans, normIn, A, x, scale, cNorm);
712 CHECKPOINT_LEAVE;
713 }
714
715 } } // namespace lapack, flens
716
717 #endif // FLENS_LAPACK_AUX_LATRS_TCC