1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35       SUBROUTINE DORMQR( SIDE, TRANS, M, N, K, A, LDA, TAU, C, LDC,
 36      $                   WORK, LWORK, INFO )
 37  *
 38  *  -- LAPACK routine (version 3.3.1) --
 39  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 40  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 41  *  -- April 2011                                                      --
 42  */
 43 
 44 #ifndef FLENS_LAPACK_QR_ORMQR_TCC
 45 #define FLENS_LAPACK_QR_ORMQR_TCC 1
 46 
 47 #include <flens/blas/blas.h>
 48 #include <flens/lapack/lapack.h>
 49 
 50 namespace flens { namespace lapack {
 51 
 52 //== generic lapack implementation =============================================
 53 
 54 template <typename MA, typename MC>
 55 typename GeMatrix<MC>::IndexType
 56 ormqr_generic_wsq(Side              side,
 57                   Transpose         trans,
 58                   GeMatrix<MA>      &A,
 59                   GeMatrix<MC>      &C)
 60 {
 61     using std::max;
 62     using std::min;
 63 
 64     typedef typename GeMatrix<MC>::ElementType  T;
 65     typedef typename GeMatrix<MC>::IndexType    IndexType;
 66 
 67     typedef typename GeMatrix<MC>::View         GeView;
 68     typedef typename GeView::Engine             GeViewEngine;
 69 
 70 //
 71 //  Paramter for maximum block size and buffer for TrMatrix Tr.
 72 //
 73     const IndexType nbMax   = 64;
 74 
 75     const IndexType m = C.numRows();
 76     const IndexType n = C.numCols();
 77     const IndexType k = A.numCols();
 78 
 79     const IndexType nw      = (side==Left) ? n : m;
 80 
 81 //
 82 //  Determine the block size.  nb may be at most nbMax, where nbMax
 83 //  is used to define the local array tr.
 84 //
 85     char opt[3];
 86     opt[0] = (side==Left) ? 'L' : 'R';
 87     if (trans==NoTrans) {
 88         opt[1] = 'N';
 89     } else if (trans==Conj) {
 90         opt[1] = 'R';
 91     } else if (trans==Trans) {
 92         opt[1] = 'T';
 93     } else if (trans==ConjTrans) {
 94         opt[1] = 'C';
 95     }
 96     opt[2] = 0;
 97 
 98     IndexType nb = min(nbMax, IndexType(ilaenv<T>(1"ORMQR", opt, m, n, k)));
 99     return max(IndexType(1), nw)*nb;
100 }
101 
102 template <typename MA, typename VTAU, typename MC, typename VWORK>
103 void
104 ormqr_generic(Side                      side,
105               Transpose                 trans,
106               GeMatrix<MA>              &A,
107               const DenseVector<VTAU>   &tau,
108               GeMatrix<MC>              &C,
109               DenseVector<VWORK>        &work)
110 {
111     using std::max;
112     using std::min;
113 
114     typedef typename GeMatrix<MC>::ElementType  T;
115     typedef typename GeMatrix<MC>::IndexType    IndexType;
116 
117     typedef typename GeMatrix<MC>::View         GeView;
118     typedef typename GeView::Engine             GeViewEngine;
119 
120     const Underscore<IndexType> _;
121 
122 
123 //
124 //  Paramter for maximum block size and buffer for TrMatrix Tr.
125 //
126     const IndexType     nbMax = 64;
127     const IndexType     ldt = nbMax + 1;
128     T                   trBuffer[nbMax*ldt];
129 
130     const IndexType m = C.numRows();
131     const IndexType n = C.numCols();
132     const IndexType k = A.numCols();
133 
134     ASSERT(tau.length()==k);
135 
136     const bool noTrans = ((trans==Trans) || (trans==ConjTrans)) ? false
137                                                                 : true;
138 //
139 //  nq is the order of Q and nw is the minimum dimension of work
140 //
141     IndexType nq, nw;
142     if (side==Left) {
143         nq = m;
144         nw = n;
145     } else {
146         nq = n;
147         nw = m;
148     }
149 //
150 //  Determine the block size.  nb may be at most nbMax, where nbMax
151 //  is used to define the local array tr.
152 //
153     char opt[3];
154     opt[0] = (side==Left) ? 'L' : 'R';
155     if (trans==NoTrans) {
156         opt[1] = 'N';
157     } else if (trans==Conj) {
158         opt[1] = 'R';
159     } else if (trans==Trans) {
160         opt[1] = 'T';
161     } else if (trans==ConjTrans) {
162         opt[1] = 'C';
163     }
164     opt[2] = 0;
165 
166     IndexType nb = min(nbMax, IndexType(ilaenv<T>(1"ORMQR", opt, m, n, k)));
167     IndexType lWorkOpt = max(IndexType(1), nw)*nb;
168 
169     if (work.length()==0) {
170         work.resize(lWorkOpt);
171     }
172 
173 //
174 //  Quick return if possible
175 //
176     if ((m==0) || (n==0) || (k==0)) {
177         work(1) = 1;
178         return;
179     }
180 
181     IndexType nbMin = 2;
182     IndexType iws;
183     if ((nb>1) && (nb<k)) {
184         iws = lWorkOpt;
185         if (work.length()<iws) {
186             nb = work.length()/nw;
187             nbMin = max(nbMin, IndexType(ilaenv<T>(2"ORMQR", opt, m, n, k)));
188         }
189     } else {
190         iws = nw;
191     }
192 
193     if ((nb<nbMin) || (nb>=k)) {
194 //
195 //      Use unblocked code
196 //
197         auto _work = (side==Left) ? work(_(1,n)) : work(_(1,m));
198         orm2r(side, trans, A, tau, C, _work);
199     } else {
200 //
201 //      Use blocked code
202 //
203         IndexType iBeg, iInc, iEnd;
204         if ((side==Left && !noTrans) || (side==Right && noTrans)) {
205             iBeg = 1;
206             iEnd = ((k-1)/nb)*nb + 1;
207             iInc = nb;
208         } else {
209             iBeg = ((k-1)/nb)*nb + 1;
210             iEnd = 1;
211             iInc = -nb;
212         }
213         iEnd += iInc;
214 
215         IndexType ic, jc;
216         if (side==Left) {
217             jc = 1;
218         } else {
219             ic = 1;
220         }
221 
222         typename GeMatrix<MA>::View Work(nw, nb, work);
223 
224         for (IndexType i=iBeg; i!=iEnd; i+=iInc) {
225             const IndexType ib = min(nb, k-i+1);
226             GeView          Tr = GeViewEngine(ib, ib, trBuffer, ldt);
227 
228 //
229 //          Form the triangular factor of the block reflector
230 //          H = H(i) H(i+1) . . . H(i+ib-1)
231 //
232             larft(Forward, ColumnWise, nq-i+1,
233                   A(_(i,nq),_(i,i+ib-1)), tau(_(i,i+ib-1)), Tr.upper());
234 
235             if (side==Left) {
236 //
237 //              H or H**T is applied to C(i:m,1:n)
238 //
239                 ic = i;
240             } else {
241 //
242 //              H or H**T is applied to C(1:m,i:n)
243 //
244                 jc = i;
245             }
246 
247             larfb(side, trans,
248                   Forward, ColumnWise,
249                   A(_(i,nq),_(i,i+ib-1)), Tr.upper(),
250                   C(_(ic,m),_(jc,n)),
251                   Work(_,_(1,ib)));
252         }
253     }
254     work(1) = lWorkOpt;
255 }
256 
257 //== interface for native lapack ===============================================
258 
259 #ifdef CHECK_CXXLAPACK
260 
261 template <typename MA, typename MC>
262 typename GeMatrix<MC>::IndexType
263 ormqr_native_wsq(Side              side,
264                  Transpose         trans,
265                  GeMatrix<MA>      &A,
266                  GeMatrix<MC>      &C)
267 {
268     typedef typename GeMatrix<MC>::ElementType  T;
269 
270     const char      SIDE    = getF77LapackChar(side);
271     const char      TRANS   = getF77LapackChar(trans);
272     const INTEGER   M       = C.numRows();
273     const INTEGER   N       = C.numCols();
274     const INTEGER   K       = A.numCols();
275     const INTEGER   LDA     = A.leadingDimension();
276     const INTEGER   LDC     = C.leadingDimension();
277     T               WORK, DUMMY;
278     const INTEGER   LWORK   = -1;
279     INTEGER         INFO;
280 
281     if (IsSame<T,DOUBLE>::value) {
282         LAPACK_IMPL(dormqr)(&SIDE,
283                             &TRANS,
284                             &M,
285                             &N,
286                             &K,
287                             A.data(),
288                             &LDA,
289                             &DUMMY,
290                             C.data(),
291                             &LDC,
292                             &WORK,
293                             &LWORK,
294                             &INFO);
295     } else {
296         ASSERT(0);
297     }
298 
299     ASSERT(INFO>=0);
300     return WORK;
301 }
302 
303 template <typename MA, typename VTAU, typename MC, typename VWORK>
304 void
305 ormqr_native(Side                       side,
306              Transpose                  trans,
307              GeMatrix<MA>               &A,
308              const DenseVector<VTAU>    &tau,
309              GeMatrix<MC>               &C,
310              DenseVector<VWORK>         &work)
311 {
312     typedef typename GeMatrix<MC>::ElementType  T;
313 
314     const char      SIDE    = getF77LapackChar(side);
315     const char      TRANS   = getF77LapackChar(trans);
316     const INTEGER   M       = C.numRows();
317     const INTEGER   N       = C.numCols();
318     const INTEGER   K       = A.numCols();
319     const INTEGER   LDA     = A.leadingDimension();
320     const INTEGER   LDC     = C.leadingDimension();
321     const INTEGER   LWORK   = work.length();
322     INTEGER         INFO;
323 
324     if (IsSame<T,DOUBLE>::value) {
325         LAPACK_IMPL(dormqr)(&SIDE,
326                             &TRANS,
327                             &M,
328                             &N,
329                             &K,
330                             A.data(),
331                             &LDA,
332                             tau.data(),
333                             C.data(),
334                             &LDC,
335                             work.data(),
336                             &LWORK,
337                             &INFO);
338     } else {
339         ASSERT(0);
340     }
341 
342     ASSERT(INFO>=0);
343 }
344 
345 #endif // CHECK_CXXLAPACK
346 
347 //== public interface ==========================================================
348 
349 template <typename MA, typename MC>
350 typename GeMatrix<MC>::IndexType
351 ormqr_wsq(Side              side,
352           Transpose         trans,
353           GeMatrix<MA>      &A,
354           GeMatrix<MC>      &C)
355 {
356     typedef typename GeMatrix<MC>::IndexType    IndexType;
357 
358 //
359 //  Test the input parameters
360 //
361 #   ifndef NDEBUG
362     const IndexType m = C.numRows();
363     const IndexType n = C.numCols();
364     const IndexType k = A.numCols();
365 
366     if (side==Left) {
367         ASSERT(A.numRows()==m);
368     } else {
369         ASSERT(A.numCols()==n);
370     }
371 #   endif
372 
373 //
374 //  Call implementation
375 //
376     const IndexType info = ormqr_generic(side, trans, A, C);
377 
378 #   ifdef CHECK_CXXLAPACK
379 //
380 //  Compare generic results with results from the native implementation
381 //
382     const IndexType _info = ormqr_native(side, trans, A, C);
383 
384     ASSERT(info==_info);
385 #   endif
386     return info;
387 }
388 
389 template <typename MA, typename VTAU, typename MC, typename VWORK>
390 void
391 ormqr(Side                      side,
392       Transpose                 trans,
393       GeMatrix<MA>              &A,
394       const DenseVector<VTAU>   &tau,
395       GeMatrix<MC>              &C,
396       DenseVector<VWORK>        &work)
397 {
398     typedef typename GeMatrix<MC>::IndexType    IndexType;
399 
400 //
401 //  Test the input parameters
402 //
403 #   ifndef NDEBUG
404     const IndexType m = C.numRows();
405     const IndexType n = C.numCols();
406     const IndexType k = A.numCols();
407 
408     ASSERT(tau.length()==k);
409 
410     if (side==Left) {
411         ASSERT(A.numRows()==m);
412     } else {
413         ASSERT(A.numRows()==n);
414     }
415 
416     if (work.length()>0) {
417         if (side==Left) {
418             ASSERT(work.length()>=n);
419         } else {
420             ASSERT(work.length()>=m);
421         }
422     }
423 #   endif
424 
425 //
426 //  Make copies of output arguments
427 //
428 #   ifdef CHECK_CXXLAPACK
429     typename GeMatrix<MA>::NoView       A_org      = A;
430     typename GeMatrix<MC>::NoView       C_org      = C;
431     typename DenseVector<VWORK>::NoView work_org   = work;
432 #   endif
433 
434 //
435 //  Call implementation
436 //
437     ormqr_generic(side, trans, A, tau, C, work);
438 
439 #   ifdef CHECK_CXXLAPACK
440 //
441 //  Make copies of results computed by the generic implementation
442 //
443     typename GeMatrix<MA>::NoView       A_generic       = A;
444     typename GeMatrix<MC>::NoView       C_generic       = C;
445     typename DenseVector<VWORK>::NoView work_generic    = work;
446 
447 //
448 //  restore output arguments
449 //
450     A = A_org;
451     C = C_org;
452     work = work_org;
453 
454 //
455 //  Compare generic results with results from the native implementation
456 //
457     ormqr_native(side, trans, A, tau, C, work);
458 
459     bool failed = false;
460     if (! isIdentical(A_generic, A, "A_generic""A")) {
461         std::cerr << "CXXLAPACK: A_generic = " << A_generic << std::endl;
462         std::cerr << "F77LAPACK: A = " << A << std::endl;
463         failed = true;
464     }
465     if (! isIdentical(C_generic, C, "C_generic""C")) {
466         std::cerr << "CXXLAPACK: C_generic = " << C_generic << std::endl;
467         std::cerr << "F77LAPACK: C = " << C << std::endl;
468         failed = true;
469     }
470     if (! isIdentical(work_generic, work, "work_generic""work")) {
471         std::cerr << "CXXLAPACK: work_generic = " << work_generic << std::endl;
472         std::cerr << "F77LAPACK: work = " << work << std::endl;
473         failed = true;
474     }
475 
476     if (failed) {
477         std::cerr << "error in: ormqr.tcc" << std::endl;
478         ASSERT(0);
479     } else {
480         // std::cerr << "passed: ormqr.tcc" << std::endl;
481     }
482 #   endif
483 }
484 
485 //-- forwarding ----------------------------------------------------------------
486 template <typename MA, typename MC>
487 typename MC::IndexType
488 ormqr_wsq(Side          side,
489           Transpose     trans,
490           MA            &&A,
491           MC            &&C)
492 {
493     typedef typename MC::IndexType IndexType;
494 
495     CHECKPOINT_ENTER;
496     const IndexType info = ormqr_wsq(side, trans, A, C);
497     CHECKPOINT_LEAVE;
498 
499     return info;
500 }
501 
502 template <typename MA, typename VTAU, typename MC, typename VWORK>
503 void
504 ormqr(Side              side,
505       Transpose         trans,
506       MA                &&A,
507       const VTAU        &tau,
508       MC                &&C,
509       VWORK             &&work)
510 {
511     CHECKPOINT_ENTER;
512     ormqr(side, trans, A, tau, C, work);
513     CHECKPOINT_LEAVE;
514 }
515 
516 } } // namespace lapack, flens
517 
518 #endif // FLENS_LAPACK_QR_ORMQR_TCC