1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
/* Based on
 34  *
 35       SUBROUTINE DGETRI( N, A, LDA, IPIV, WORK, LWORK, INFO )
 36 
 37       SUBROUTINE DTRTRI( UPLO, DIAG, N, A, LDA, INFO )
 38  *
 39  *  -- LAPACK routine (version 3.2) --
 40  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 41  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 42  *     November 2006
 43  */
 44 
 45 #ifndef FLENS_LAPACK_GESV_TRI_TCC
 46 #define FLENS_LAPACK_GESV_TRI_TCC 1
 47 
 48 #include <algorithm>
 49 #include <flens/blas/blas.h>
 50 #include <flens/lapack/lapack.h>
 51 
 52 #include <flens/lapack/interface/include/f77lapack.h>
 53 
 54 namespace flens { namespace lapack {
 55 
 56 //== generic lapack implementation =============================================
 57 //-- (ge)tri
 58 template <typename MA, typename VP, typename VWORK>
 59 typename GeMatrix<MA>::IndexType
 60 tri_generic(GeMatrix<MA>            &A,
 61             const DenseVector<VP>   &piv,
 62             DenseVector<VWORK>      &work)
 63 {
 64     using std::max;
 65 
 66     typedef typename GeMatrix<MA>::ElementType ElementType;
 67     typedef typename GeMatrix<MA>::IndexType   IndexType;
 68 
 69     const ElementType Zero(0), One(1);
 70     const IndexType n = A.numRows();
 71     const Underscore<IndexType> _;
 72 
 73     IndexType info = 0;
 74     IndexType nb = ilaenv<ElementType>(1"GETRI""", n);
 75 
 76     const IndexType lWorkOpt = n*nb;
 77 
 78     if (work.length()==0) {
 79         work.resize(max(lWorkOpt, IndexType(1)));
 80     }
 81     work(1) = lWorkOpt;
 82 
 83 //
 84 //  Quick return if possible
 85 //
 86     if (n==0) {
 87         return info;
 88     }
 89 //
 90 //  Form inv(U).  If INFO > 0 from DTRTRI, then U is singular,
 91 //  and the inverse is not computed.
 92 //
 93     info = tri(A.upper());
 94     if (info>0) {
 95         return info;
 96     }
 97 
 98     IndexType nbMin = 2;
 99     const IndexType lWork  = work.length();;
100     const IndexType ldWork = n;
101 
102     IndexType iws;
103 
104     if (nb>1 && nb<n) {
105         iws = max(ldWork*nb, IndexType(1));
106         if (lWork<iws) {
107             nb = lWork / ldWork;
108             nbMin = max(2, ilaenv<ElementType>(2"GETRI""", n));
109         }
110     } else {
111         iws = n;
112     }
113 
114     GeMatrixView<ElementType> Work(n, nb, work);
115 //
116 //  Solve the equation inv(A)*L = inv(U) for inv(A).
117 //
118     if (nb<nbMin || nb>=n) {
119 //
120 //      Use unblocked code.
121 //
122         for (IndexType j=n; j>=1; --j) {
123 //
124 //          Copy current column of L to WORK and replace with zeros.
125 //
126             work(_(j+1,n)) = A(_(j+1,n),j);
127             A(_(j+1,n),j)  = Zero;;
128 //
129 //          Compute current column of inv(A).
130 //
131             if (j<n) {
132                 blas::mv(NoTrans, -One,
133                          A(_,_(j+1,n)), work(_(j+1,n)),
134                          One,
135                          A(_,j));
136             }
137         }
138     } else {
139 //
140 //      Use blocked code.
141 //
142         const IndexType nn = ((n-1)/nb)*nb + 1;
143         for (IndexType j=nn; j>=1; j-=nb) {
144             const IndexType jb = min(nb, n-j+1);
145 //
146 //          Copy current block column of L to WORK and replace with
147 //          zeros.
148 //
149             for (IndexType jj=j, JJ=1; jj<=j+jb-1; ++jj, ++JJ) {
150                 Work(_(jj+1,n),JJ) = A(_(jj+1,n),jj);
151                 A(_(jj+1,n),jj)    = Zero;
152             }
153 //
154 //          Compute current block column of inv(A).
155 //
156             if (j+jb<=n) {
157                 blas::mm(NoTrans, NoTrans,
158                          -One,
159                          A(_,_(j+jb,n)),
160                          Work(_(j+jb,n),_(1,jb)),
161                          One,
162                          A(_,_(j,j+jb-1)));
163             }
164             blas::sm(Right, NoTrans,
165                      One, Work(_(j,j+jb-1),_(1,jb)).lowerUnit(),
166                      A(_,_(j,j+jb-1)));
167         }
168     }
169 //
170 //  Apply column interchanges.
171 //
172     for (IndexType j=n-1; j>=1; --j) {
173         const IndexType jp = piv(j);
174         if (jp!=j) {
175             blas::swap(A(_,j), A(_,jp));
176         }
177     }
178 
179     work(1) = iws;
180     return info;
181 }
182 
183 //-- (tr)tri
184 template <typename MA>
185 typename GeMatrix<MA>::IndexType
186 tri_generic(TrMatrix<MA> &A)
187 {
188     using std::min;
189     using cxxblas::getF77BlasChar;
190 
191     typedef typename TrMatrix<MA>::ElementType ElementType;
192     typedef typename TrMatrix<MA>::IndexType   IndexType;
193 
194     const Underscore<IndexType> _;
195 
196     const IndexType n = A.dim();
197     const bool upper  = (A.upLo()==Upper);
198     const bool noUnit = (A.diag()==NonUnit);
199 
200     const ElementType  Zero(0), One(1);
201 
202     IndexType info = 0;
203 //
204 //  Quick return if possible
205 //
206     if (n==0) {
207         return info;
208     }
209 //
210 //  Check for singularity if non-unit.
211 //
212     if (noUnit) {
213         for (info=1; info<=n; ++info) {
214             if (A(info,info)==Zero) {
215                 return info;
216             }
217         }
218         info = 0;
219     }
220 
221 //
222 //  Determine the block size for this environment.  //
223     const char upLoDiag[2] = { getF77BlasChar(A.upLo()),
224                                getF77BlasChar(A.diag()) };
225     const IndexType nb = ilaenv<ElementType>(1"TRTRI", upLoDiag, n);
226 
227     if (nb<=1 || nb>=n) {
228 //
229 //      Use unblocked code
230 //
231         info = ti2(A);
232     } else {
233 //
234 //      Use blocked code
235 //
236 
237         if (upper) {
238 //
239 //          Compute inverse of upper triangular matrix
240 //
241             for (IndexType j=1; j<=n; j+=nb) {
242                 const IndexType jb = min(nb, n-j+1);
243 //
244 //              Compute rows 1:j-1 of current block column
245 //
246                 const auto range1 = _(1,j-1);
247                 const auto range2 = _(j,j+jb-1);
248 
249                 const auto U11 = (noUnit) ? A(range1, range1).upper()
250                                           : A(range1, range1).upperUnit();
251                 auto U22 = (noUnit) ? A(range2, range2).upper()
252                                     : A(range2, range2).upperUnit();
253                 auto U12 = A(range1, range2);
254 
255                 blas::mm(Left, NoTrans, One, U11, U12);
256                 blas::sm(Right, NoTrans, -One, U22, U12);
257 //
258 //              Compute inverse of current diagonal block
259 //
260                 info = ti2(U22);
261             }
262         } else {
263 //
264 //          Compute inverse of lower triangular matrix
265 //
266             const IndexType nn = ((n-1)/nb)*nb + 1;
267             for (IndexType j=nn; j>=1; j-=nb) {
268                 const IndexType jb = min(nb, n-j+1);
269 
270                 const auto range1 = _(j,j+jb-1);
271                 auto L11 = (noUnit) ? A(range1,range1).lower()
272                                     : A(range1,range1).lowerUnit();
273 
274                 if (j+jb<=n) {
275 //
276 //                  Compute rows j+jb:n of current block column
277 //
278                     const auto range2 = _(j+jb,n);
279 
280                     const auto L22 = (noUnit) ? A(range2,range2).lower()
281                                               : A(range2,range2).lowerUnit();
282 
283                     auto L21 = A(range2, range1);
284 
285                     blas::mm(Left, NoTrans, One, L22, L21);
286                     blas::sm(Right, NoTrans, -One, L11, L21);
287                 }
288 //
289 //              Compute inverse of current diagonal block
290 //
291                 info = ti2(L11);
292             }
293         }
294     }
295     return info;
296 }
297 
298 //== interface for native lapack ===============================================
299 
300 #ifdef CHECK_CXXLAPACK
301 //-- (ge)tri
302 template <typename MA, typename VP, typename VWORK>
303 typename GeMatrix<MA>::IndexType
304 tri_native(GeMatrix<MA>             &A,
305            const DenseVector<VP>    &piv,
306            DenseVector<VWORK>       &work)
307 {
308     typedef typename GeMatrix<MA>::ElementType  ElementType;
309     typedef typename GeMatrix<MA>::IndexType    IndexType;
310 
311     const INTEGER    N      = A.numRows();
312     const INTEGER    LDA    = A.leadingDimension();
313     const INTEGER    LWORK  = work.length();
314     INTEGER          INFO;
315 
316     if (IsSame<ElementType, DOUBLE>::value) {
317         LAPACK_IMPL(dgetri)(&N,
318                             A.data(),
319                             &LDA,
320                             piv.data(),
321                             work.data(),
322                             &LWORK,
323                             &INFO);
324     } else {
325         ASSERT(0);
326     }
327     ASSERT(INFO>=0);
328     return INFO;
329 }
330 
331 //-- (tr)tri
332 template <typename MA>
333 typename GeMatrix<MA>::IndexType
334 tri_native(TrMatrix<MA> &A)
335 {
336     typedef typename GeMatrix<MA>::ElementType  ElementType;
337 
338     const char       UPLO   = getF77BlasChar(A.upLo());
339     const char       DIAG   = getF77BlasChar(A.diag());
340     const INTEGER    N      = A.dim();
341     const INTEGER    LDA    = A.leadingDimension();
342     INTEGER          INFO;
343 
344     if (IsSame<ElementType, DOUBLE>::value) {
345         LAPACK_IMPL(dtrtri)(&UPLO,
346                             &DIAG,
347                             &N,
348                             A.data(),
349                             &LDA,
350                             &INFO);
351     } else {
352         ASSERT(0);
353     }
354     ASSERT(INFO>=0);
355     return INFO;
356 }
357 
358 #endif // CHECK_CXXLAPACK
359 
360 //== public interface ==========================================================
361 
362 //-- (ge)tri
363 template <typename MA, typename VP, typename VWORK>
364 typename GeMatrix<MA>::IndexType
365 tri(GeMatrix<MA> &A, const DenseVector<VP> &piv, DenseVector<VWORK> &work)
366 {
367     using std::max;
368 
369     typedef typename GeMatrix<MA>::IndexType IndexType;
370 //
371 //  Test the input parameters
372 //
373 #   ifndef NDEBUG
374     ASSERT(A.firstRow()==1);
375     ASSERT(A.firstCol()==1);
376     ASSERT(A.numRows()==A.numCols());
377 
378     const IndexType n = A.numRows();
379 
380     ASSERT(piv.firstIndex()==1);
381     ASSERT(piv.length()==n);
382 
383     const bool lQuery = (work.length()==0);
384     ASSERT(lQuery || work.length()>=n);
385 #   endif
386 
387 //
388 //  Make copies of output arguments
389 //
390 #   ifdef CHECK_CXXLAPACK
391     typename GeMatrix<MA>::NoView        A_org    = A;
392     typename DenseVector<VWORK>::NoView  work_org = work;
393 #   endif
394 
395 //
396 //  Call implementation
397 //
398     const IndexType info = tri_generic(A, piv, work);
399 
400 //
401 //  Compare results
402 //
403 #   ifdef CHECK_CXXLAPACK
404     typename GeMatrix<MA>::NoView        A_generic    = A;
405     typename DenseVector<VWORK>::NoView  work_generic = work;
406 
407     A    = A_org;
408     work = work_org;
409 
410     const IndexType _info = tri_native(A, piv, work);
411 
412     bool failed = false;
413     if (! isIdentical(A_generic, A, "A_generic""A")) {
414         std::cerr << "CXXLAPACK: A_generic = " << A_generic << std::endl;
415         std::cerr << "F77LAPACK: A = " << A << std::endl;
416         failed = true;
417     }
418 
419     if (! isIdentical(work_generic, work, "work_generic""work")) {
420         std::cerr << "CXXLAPACK: work_generic = " << work_generic << std::endl;
421         std::cerr << "F77LAPACK: work = " << work << std::endl;
422         failed = true;
423     }
424 
425     if (! isIdentical(info, _info, " info""_info")) {
426         std::cerr << "CXXLAPACK:  info = " << info << std::endl;
427         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
428         failed = true;
429     }
430 
431     if (failed) {
432         ASSERT(0);
433     } else {
434         // std::cerr << "passed: (ge)tri.tcc" << std::endl;
435     }
436 #   endif
437 
438     return info;
439 }
440 
441 //-- (tr)tri
442 template <typename MA>
443 typename GeMatrix<MA>::IndexType
444 tri(TrMatrix<MA> &A)
445 {
446     typedef typename GeMatrix<MA>::IndexType IndexType;
447 
448 //
449 //  Test the input parameters
450 //
451 #   ifndef NDEBUG
452     ASSERT(A.firstRow()==1);
453     ASSERT(A.firstCol()==1);
454 #   endif
455 
456 //
457 //  Make copies of output arguments
458 //
459 #   ifdef CHECK_CXXLAPACK
460     typename TrMatrix<MA>::NoView   A_org = A;
461 #   endif
462 
463 //
464 //  Call implementation
465 //
466     const IndexType info = tri_generic(A);
467 
468 //
469 //  Compare results
470 //
471 #   ifdef CHECK_CXXLAPACK
472     typename TrMatrix<MA>::NoView   A_generic = A;
473 
474     A = A_org;
475 
476     const IndexType _info = tri_native(A);
477 
478     bool failed = false;
479     if (! isIdentical(A_generic, A, "A_generic""A")) {
480         std::cerr << "CXXLAPACK: A_generic = " << A_generic << std::endl;
481         std::cerr << "F77LAPACK: A = " << A << std::endl;
482         failed = true;
483     }
484 
485     if (! isIdentical(info, _info, " info""_info")) {
486         std::cerr << "CXXLAPACK:  info = " << info << std::endl;
487         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
488         failed = true;
489     }
490 
491     if (failed) {
492         ASSERT(0);
493     } else {
494         // std::cerr << "passed: (tr)tri.tcc" << std::endl;
495     }
496 #   endif
497 
498     return info;
499 }
500 
501 //-- forwarding ----------------------------------------------------------------
502 template <typename MA, typename VP, typename VWORK>
503 typename MA::IndexType
504 tri(MA &&A, const VP &&piv, VWORK &&work)
505 {
506     typedef typename MA::IndexType  IndexType;
507 
508     CHECKPOINT_ENTER;
509     IndexType info = tri(A, piv, work);
510     CHECKPOINT_LEAVE;
511 
512     return info;
513 }
514 
515 template <typename MA>
516 typename MA::IndexType
517 tri(MA &&A)
518 {
519     typedef typename MA::IndexType  IndexType;
520 
521     CHECKPOINT_ENTER;
522     IndexType info = tri(A);
523     CHECKPOINT_LEAVE;
524 
525     return info;
526 }
527 
528 } } // namespace lapack, flens
529 
530 #endif // FLENS_LAPACK_GESV_TRI_TCC