1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
/* Based on
 34  *
 35       SUBROUTINE DGETRI( N, A, LDA, IPIV, WORK, LWORK, INFO )
 36 
 37       SUBROUTINE DTRTRI( UPLO, DIAG, N, A, LDA, INFO )
 38  *
 39  *  -- LAPACK routine (version 3.2) --
 40  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 41  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 42  *     November 2006
 43  */
 44 
 45 #ifndef FLENS_LAPACK_GESV_TRI_TCC
 46 #define FLENS_LAPACK_GESV_TRI_TCC 1
 47 
 48 #include <algorithm>
 49 #include <flens/blas/blas.h>
 50 #include <flens/lapack/lapack.h>
 51 
 52 #include <flens/lapack/interface/include/f77lapack.h>
 53 
 54 namespace flens { namespace lapack {
 55 
 56 //== generic lapack implementation =============================================
 57 //-- (ge)tri
 58 template <typename MA, typename VP, typename VWORK>
 59 typename GeMatrix<MA>::IndexType
 60 tri_generic(GeMatrix<MA> &A, DenseVector<VP> &piv, DenseVector<VWORK> &work)
 61 {
 62     using std::max;
 63 
 64     typedef typename GeMatrix<MA>::ElementType ElementType;
 65     typedef typename GeMatrix<MA>::IndexType   IndexType;
 66 
 67     const ElementType Zero(0), One(1);
 68     const IndexType n = A.numRows();
 69     const Underscore<IndexType> _;
 70 
 71     IndexType info = 0;
 72     IndexType nb = ilaenv<ElementType>(1"GETRI""", n);
 73 
 74     const IndexType lWorkOpt = n*nb;
 75 
 76     if (work.length()==0) {
 77         work.resize(max(lWorkOpt, IndexType(1)));
 78     }
 79     work(1) = lWorkOpt;
 80 
 81 //
 82 //  Quick return if possible
 83 //
 84     if (n==0) {
 85         return info;
 86     }
 87 //
 88 //  Form inv(U).  If INFO > 0 from DTRTRI, then U is singular,
 89 //  and the inverse is not computed.
 90 //
 91     info = tri(A.upper());
 92     if (info>0) {
 93         return info;
 94     }
 95 
 96     IndexType nbMin = 2;
 97     const IndexType lWork  = work.length();;
 98     const IndexType ldWork = n;
 99 
100     IndexType iws;
101 
102     if (nb>1 && nb<n) {
103         iws = max(ldWork*nb, IndexType(1));
104         if (lWork<iws) {
105             nb = lWork / ldWork;
106             nbMin = max(2, ilaenv<ElementType>(2"GETRI""", n));
107         }
108     } else {
109         iws = n;
110     }
111 
112     GeMatrixView<ElementType> Work(n, nb, work);
113 //
114 //  Solve the equation inv(A)*L = inv(U) for inv(A).
115 //
116     if (nb<nbMin || nb>=n) {
117 //
118 //      Use unblocked code.
119 //
120         for (IndexType j=n; j>=1; --j) {
121 //
122 //          Copy current column of L to WORK and replace with zeros.
123 //
124             work(_(j+1,n)) = A(_(j+1,n),j);
125             A(_(j+1,n),j)  = Zero;;
126 //
127 //          Compute current column of inv(A).
128 //
129             if (j<n) {
130                 blas::mv(NoTrans, -One,
131                          A(_,_(j+1,n)), work(_(j+1,n)),
132                          One,
133                          A(_,j));
134             }
135         }
136     } else {
137 //
138 //      Use blocked code.
139 //
140         const IndexType nn = ((n-1)/nb)*nb + 1;
141         for (IndexType j=nn; j>=1; j-=nb) {
142             const IndexType jb = min(nb, n-j+1);
143 //
144 //          Copy current block column of L to WORK and replace with
145 //          zeros.
146 //
147             for (IndexType jj=j, JJ=1; jj<=j+jb-1; ++jj, ++JJ) {
148                 Work(_(jj+1,n),JJ) = A(_(jj+1,n),jj);
149                 A(_(jj+1,n),jj)    = Zero;
150             }
151 //
152 //          Compute current block column of inv(A).
153 //
154             if (j+jb<=n) {
155                 blas::mm(NoTrans, NoTrans,
156                          -One,
157                          A(_,_(j+jb,n)),
158                          Work(_(j+jb,n),_(1,jb)),
159                          One,
160                          A(_,_(j,j+jb-1)));
161             }
162             blas::sm(Right, NoTrans,
163                      One, Work(_(j,j+jb-1),_(1,jb)).lowerUnit(),
164                      A(_,_(j,j+jb-1)));
165         }
166     }
167 //
168 //  Apply column interchanges.
169 //
170     for (IndexType j=n-1; j>=1; --j) {
171         const IndexType jp = piv(j);
172         if (jp!=j) {
173             blas::swap(A(_,j), A(_,jp));
174         }
175     }
176 
177     work(1) = iws;
178     return info;
179 }
180 
181 //-- (tr)tri
182 template <typename MA>
183 typename GeMatrix<MA>::IndexType
184 tri_generic(TrMatrix<MA> &A)
185 {
186     using std::min;
187     using cxxblas::getF77BlasChar;
188 
189     typedef typename TrMatrix<MA>::ElementType ElementType;
190     typedef typename TrMatrix<MA>::IndexType   IndexType;
191 
192     const Underscore<IndexType> _;
193 
194     const IndexType n = A.dim();
195     const bool upper  = (A.upLo()==Upper);
196     const bool noUnit = (A.diag()==NonUnit);
197 
198     const ElementType  Zero(0), One(1);
199 
200     IndexType info = 0;
201 //
202 //  Quick return if possible
203 //
204     if (n==0) {
205         return info;
206     }
207 //
208 //  Check for singularity if non-unit.
209 //
210     if (noUnit) {
211         for (info=1; info<=n; ++info) {
212             if (A(info,info)==Zero) {
213                 return info;
214             }
215         }
216         info = 0;
217     }
218 
219 //
220 //  Determine the block size for this environment.  //
221     const char upLoDiag[2] = { getF77BlasChar(A.upLo()),
222                                getF77BlasChar(A.diag()) };
223     const IndexType nb = ilaenv<ElementType>(1"TRTRI", upLoDiag, n);
224 
225     if (nb<=1 || nb>=n) {
226 //
227 //      Use unblocked code
228 //
229         info = ti2(A);
230     } else {
231 //
232 //      Use blocked code
233 //
234 
235         if (upper) {
236 //
237 //          Compute inverse of upper triangular matrix
238 //
239             for (IndexType j=1; j<=n; j+=nb) {
240                 const IndexType jb = min(nb, n-j+1);
241 //
242 //              Compute rows 1:j-1 of current block column
243 //
244                 const auto range1 = _(1,j-1);
245                 const auto range2 = _(j,j+jb-1);
246 
247                 const auto U11 = (noUnit) ? A(range1, range1).upper()
248                                           : A(range1, range1).upperUnit();
249                 auto U22 = (noUnit) ? A(range2, range2).upper()
250                                     : A(range2, range2).upperUnit();
251                 auto U12 = A(range1, range2);
252 
253                 blas::mm(Left, NoTrans, One, U11, U12);
254                 blas::sm(Right, NoTrans, -One, U22, U12);
255 //
256 //              Compute inverse of current diagonal block
257 //
258                 info = ti2(U22);
259             }
260         } else {
261 //
262 //          Compute inverse of lower triangular matrix
263 //
264             const IndexType nn = ((n-1)/nb)*nb + 1;
265             for (IndexType j=nn; j>=1; j-=nb) {
266                 const IndexType jb = min(nb, n-j+1);
267 
268                 const auto range1 = _(j,j+jb-1);
269                 auto L11 = (noUnit) ? A(range1,range1).lower()
270                                     : A(range1,range1).lowerUnit();
271 
272                 if (j+jb<=n) {
273 //
274 //                  Compute rows j+jb:n of current block column
275 //
276                     const auto range2 = _(j+jb,n);
277 
278                     const auto L22 = (noUnit) ? A(range2,range2).lower()
279                                               : A(range2,range2).lowerUnit();
280 
281                     auto L21 = A(range2, range1);
282 
283                     blas::mm(Left, NoTrans, One, L22, L21);
284                     blas::sm(Right, NoTrans, -One, L11, L21);
285                 }
286 //
287 //              Compute inverse of current diagonal block
288 //
289                 info = ti2(L11);
290             }
291         }
292     }
293     return info;
294 }
295 
296 //== interface for native lapack ===============================================
297 
298 #ifdef CHECK_CXXLAPACK
299 //-- (ge)tri
300 template <typename MA, typename VP, typename VWORK>
301 typename GeMatrix<MA>::IndexType
302 tri_native(GeMatrix<MA> &A, DenseVector<VP> &piv, DenseVector<VWORK> &work)
303 {
304     typedef typename GeMatrix<MA>::ElementType  ElementType;
305     typedef typename GeMatrix<MA>::IndexType    IndexType;
306 
307     const INTEGER    N      = A.numRows();
308     const INTEGER    LDA    = A.leadingDimension();
309     const INTEGER    LWORK  = work.length();
310     INTEGER          INFO;
311 
312     if (IsSame<ElementType, DOUBLE>::value) {
313         LAPACK_IMPL(dgetri)(&N,
314                             A.data(),
315                             &LDA,
316                             piv.data(),
317                             work.data(),
318                             &LWORK,
319                             &INFO);
320     } else {
321         ASSERT(0);
322     }
323     ASSERT(INFO>=0);
324     return INFO;
325 }
326 
327 //-- (tr)tri
328 template <typename MA>
329 typename GeMatrix<MA>::IndexType
330 tri_native(TrMatrix<MA> &A)
331 {
332     typedef typename GeMatrix<MA>::ElementType  ElementType;
333 
334     const char       UPLO   = getF77BlasChar(A.upLo());
335     const char       DIAG   = getF77BlasChar(A.diag());
336     const INTEGER    N      = A.dim();
337     const INTEGER    LDA    = A.leadingDimension();
338     INTEGER          INFO;
339 
340     if (IsSame<ElementType, DOUBLE>::value) {
341         LAPACK_IMPL(dtrtri)(&UPLO,
342                             &DIAG,
343                             &N,
344                             A.data(),
345                             &LDA,
346                             &INFO);
347     } else {
348         ASSERT(0);
349     }
350     ASSERT(INFO>=0);
351     return INFO;
352 }
353 
354 #endif // CHECK_CXXLAPACK
355 
356 //== public interface ==========================================================
357 
358 //-- (ge)tri
359 template <typename MA, typename VP, typename VWORK>
360 typename GeMatrix<MA>::IndexType
361 tri(GeMatrix<MA> &A, DenseVector<VP> &piv, DenseVector<VWORK> &work)
362 {
363     using std::max;
364 
365     typedef typename GeMatrix<MA>::IndexType IndexType;
366 //
367 //  Test the input parameters
368 //
369 #   ifndef NDEBUG
370     ASSERT(A.firstRow()==1);
371     ASSERT(A.firstCol()==1);
372     ASSERT(A.numRows()==A.numCols());
373 
374     const IndexType n = A.numRows();
375 
376     ASSERT(piv.firstIndex()==1);
377     ASSERT(piv.length()==n);
378 
379     const bool lQuery = (work.length()==0);
380     ASSERT(lQuery || work.length()>=n);
381 #   endif
382 
383 //
384 //  Make copies of output arguments
385 //
386 #   ifdef CHECK_CXXLAPACK
387     typename GeMatrix<MA>::NoView        A_org    = A;
388     typename DenseVector<VP>::NoView     piv_org  = piv;
389     typename DenseVector<VWORK>::NoView  work_org = work;
390 #   endif
391 
392 //
393 //  Call implementation
394 //
395     const IndexType info = tri_generic(A, piv, work);
396 
397 //
398 //  Compare results
399 //
400 #   ifdef CHECK_CXXLAPACK
401     typename GeMatrix<MA>::NoView        A_generic    = A;
402     typename DenseVector<VP>::NoView     piv_generic  = piv;
403     typename DenseVector<VWORK>::NoView  work_generic = work;
404 
405     A    = A_org;
406     piv  = piv_org;
407     work = work_org;
408 
409     const IndexType _info = tri_native(A, piv, work);
410 
411     bool failed = false;
412     if (! isIdentical(A_generic, A, "A_generic""A")) {
413         std::cerr << "CXXLAPACK: A_generic = " << A_generic << std::endl;
414         std::cerr << "F77LAPACK: A = " << A << std::endl;
415         failed = true;
416     }
417 
418     if (! isIdentical(piv_generic, piv, "piv_generic""piv")) {
419         std::cerr << "CXXLAPACK: piv_generic = " << piv_generic << std::endl;
420         std::cerr << "F77LAPACK: piv = " << piv << std::endl;
421         failed = true;
422     }
423 
424     if (! isIdentical(work_generic, work, "work_generic""work")) {
425         std::cerr << "CXXLAPACK: work_generic = " << work_generic << std::endl;
426         std::cerr << "F77LAPACK: work = " << work << std::endl;
427         failed = true;
428     }
429 
430     if (! isIdentical(info, _info, " info""_info")) {
431         std::cerr << "CXXLAPACK:  info = " << info << std::endl;
432         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
433         failed = true;
434     }
435 
436     if (failed) {
437         ASSERT(0);
438     } else {
439         // std::cerr << "passed: (ge)tri.tcc" << std::endl;
440     }
441 #   endif
442 
443     return info;
444 }
445 
446 //-- (tr)tri
447 template <typename MA>
448 typename GeMatrix<MA>::IndexType
449 tri(TrMatrix<MA> &A)
450 {
451     typedef typename GeMatrix<MA>::IndexType IndexType;
452 
453 //
454 //  Test the input parameters
455 //
456 #   ifndef NDEBUG
457     ASSERT(A.firstRow()==1);
458     ASSERT(A.firstCol()==1);
459 #   endif
460 
461 //
462 //  Make copies of output arguments
463 //
464 #   ifdef CHECK_CXXLAPACK
465     typename TrMatrix<MA>::NoView   A_org = A;
466 #   endif
467 
468 //
469 //  Call implementation
470 //
471     const IndexType info = tri_generic(A);
472 
473 //
474 //  Compare results
475 //
476 #   ifdef CHECK_CXXLAPACK
477     typename TrMatrix<MA>::NoView   A_generic = A;
478 
479     A = A_org;
480 
481     const IndexType _info = tri_native(A);
482 
483     bool failed = false;
484     if (! isIdentical(A_generic, A, "A_generic""A")) {
485         std::cerr << "CXXLAPACK: A_generic = " << A_generic << std::endl;
486         std::cerr << "F77LAPACK: A = " << A << std::endl;
487         failed = true;
488     }
489 
490     if (! isIdentical(info, _info, " info""_info")) {
491         std::cerr << "CXXLAPACK:  info = " << info << std::endl;
492         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
493         failed = true;
494     }
495 
496     if (failed) {
497         ASSERT(0);
498     } else {
499         // std::cerr << "passed: (tr)tri.tcc" << std::endl;
500     }
501 #   endif
502 
503     return info;
504 }
505 
506 //-- forwarding ----------------------------------------------------------------
507 template <typename MA, typename VP, typename VWORK>
508 typename MA::IndexType
509 tri(MA &&A, VP &&piv, VWORK &&work)
510 {
511     typedef typename MA::IndexType  IndexType;
512 
513     CHECKPOINT_ENTER;
514     IndexType info = tri(A, piv, work);
515     CHECKPOINT_LEAVE;
516 
517     return info;
518 }
519 
520 template <typename MA>
521 typename MA::IndexType
522 tri(MA &&A)
523 {
524     typedef typename MA::IndexType  IndexType;
525 
526     CHECKPOINT_ENTER;
527     IndexType info = tri(A);
528     CHECKPOINT_LEAVE;
529 
530     return info;
531 }
532 
533 } } // namespace lapack, flens
534 
535 #endif // FLENS_LAPACK_GESV_TRI_TCC