1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35       SUBROUTINE DLARFB( SIDE, TRANS, DIRECT, STOREV, M, N, K, V, LDV,
 36      $                   T, LDT, C, LDC, WORK, LDWORK )
 37  *
 38  *  -- LAPACK auxiliary routine (version 3.3.1) --
 39  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 40  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 41  *  -- April 2011                                                      --
 42  *
 43  */
 44 
 45 #ifndef FLENS_LAPACK_AUX_LARFB_TCC
 46 #define FLENS_LAPACK_AUX_LARFB_TCC 1
 47 
 48 #include <flens/blas/blas.h>
 49 #include <flens/lapack/lapack.h>
 50 
 51 namespace flens { namespace lapack {
 52 
 53 //== generic lapack implementation =============================================
 54 
 55 template <typename MV, typename MT, typename MC, typename MWORK>
 56 void
 57 larfb_generic(Side                  side,
 58               Transpose             transH,
 59               Direction             direction,
 60               StoreVectors          storeVectors,
 61               const GeMatrix<MV>    &V,
 62               const TrMatrix<MT>    &Tr,
 63               GeMatrix<MC>          &C,
 64               GeMatrix<MWORK>       &Work)
 65 {
 66     using lapack::ilalc;
 67     using lapack::ilalr;
 68     using std::max;
 69 
 70     typedef typename GeMatrix<MC>::ElementType  T;
 71     typedef typename GeMatrix<MC>::IndexType    IndexType;
 72 
 73     const Underscore<IndexType> _;
 74     const T                     One(1);
 75 
 76     const IndexType m = C.numRows();
 77     const IndexType n = C.numCols();
 78     const IndexType k = Tr.dim();
 79 
 80 //
 81 //  Quick return if possible
 82 //
 83     if ((m==0) || (n==0)) {
 84         return;
 85     }
 86 
 87     const Transpose transT = (transH==NoTrans) ? ConjTrans : NoTrans;
 88 
 89     if (storeVectors==ColumnWise) {
 90         if (direction==Forward) {
 91 //
 92 //          Let  V =  ( V1 )    (first K rows)
 93 //                    ( V2 )
 94 //          where  V1  is unit lower triangular.
 95 //
 96             if (side==Left) {
 97 //
 98 //              Form  H * C  or  H**T * C  where  C = ( C1 )
 99 //                                                    ( C2 )
100 //
101                 const IndexType lastV = max(k, ilalr(V));
102                 const auto V1 = V(_(1,k),_);
103                 const auto V2 = V(_(k+1,lastV),_);
104 
105                 const IndexType lastC = ilalc(C(_(1,lastV),_));
106                 auto C1 = C(_(1,k),_(1,lastC));
107                 auto C2 = C(_(k+1,lastV),_(1,lastC));
108 
109 //
110 //              W := C**T * V  =  (C1**T * V1 + C2**T * V2)  (stored in WORK)
111 //
112 //              W := C1**T
113 //
114                 auto W = Work(_(1,lastC),_(1,k));
115                 blas::copy(Trans, C1, W);
116 //
117 //              W := W * V1
118 //
119                 blas::mm(Right, NoTrans, One, V1.lowerUnit(), W);
120 
121                 if (lastV>k) {
122 //
123 //                  W := W + C2**T *V2
124 //
125                     blas::mm(Trans, NoTrans, One, C2, V2, One, W);
126                 }
127 //
128 //              W := W * T**T  or  W * T
129 //
130                 blas::mm(Right, transT, One, Tr, W);
131 //
132 //              C := C - V * W**T
133 //
134                 if (lastV>k) {
135 //
136 //                  C2 := C2 - V2 * W**T
137 //
138                     blas::mm(NoTrans, Trans, -One, V2, W, One, C2);
139                 }
140 //
141 //              W := W * V1**T
142 //
143                 blas::mm(Right, Trans, One, V1.lowerUnit(), W);
144 //
145 //              C1 := C1 - W**T
146 //
147                 blas::axpy(Trans, -One, W, C1);
148             } else if (side==Right) {
149 //
150 //              Form  C * H  or  C * H**T  where  C = ( C1  C2 )
151 //
152                 const IndexType lastV = max(k, ilalr(V));
153                 const auto V1 = V(_(1,k),_);
154                 const auto V2 = V(_(k+1,lastV),_);
155 
156                 const IndexType lastC = ilalr(C(_,_(1,lastV)));
157                 auto C1 = C(_(1,lastC),_(1,k));
158                 auto C2 = C(_(1,lastC),_(k+1,lastV));
159 //
160 //              W := C * V  =  (C1*V1 + C2*V2)  (stored in WORK)
161 //
162 //              W := C1
163 //
164                 auto W = Work(_(1,lastC),_(1,k));
165                 W = C1;
166 //
167 //              W := W * V1
168 //
169                 blas::mm(Right, NoTrans, One, V1.lowerUnit(), W);
170 
171                 if (lastV>k) {
172 //
173 //                  W := W + C2 * V2
174 //
175                     blas::mm(NoTrans, NoTrans, One, C2, V2, One, W);
176                 }
177 //
178 //              W := W * T  or  W * T**T
179 //
180                 blas::mm(Right, transH, One, Tr, W);
181 //
182 //              C := C - W * V**T
183 //
184                 if (lastV>k) {
185 //
186 //                  C2 := C2 - W * V2**T
187 //
188                     blas::mm(NoTrans, Trans, -One, W, V2, One, C2);
189                 }
190 //
191 //              W := W * V1**T
192 //
193                 blas::mm(Right, Trans, One, V1.lowerUnit(), W);
194 //
195 //              C1 := C1 - W
196 //
197                 blas::axpy(NoTrans, -One, W, C1);
198             }
199         } else if (direction==Backward) {
200             // Lehn: I will implement it as soon as someone needs it
201             ASSERT(0);
202         }
203     } else if (storeVectors==RowWise) {
204 
205         if (direction==Forward) {
206 //
207 //          Let  V =  ( V1  V2 )    (V1: first K columns)
208 //          where  V1  is unit upper triangular.
209 //
210             if (side==Left) {
211 //
212 //              Form  H * C  or  H**T * C  where  C = ( C1 )
213 //                                                    ( C2 )
214 //
215                 const IndexType lastV = max(k, ilalc(V));
216                 const auto V1 = V(_,_(1,k));
217                 const auto V2 = V(_,_(k+1,lastV));
218 
219                 const IndexType lastC = ilalc(C(_(1,lastV),_));
220                 auto C1 = C(_(  1,    k),_(1,lastC));
221                 auto C2 = C(_(k+1,lastV),_(1,lastC));
222 //
223 //              W := C**T * V**T  =  (C1**T * V1**T + C2**T * V2**T)
224 //                                                              (stored in WORK)
225 //              W := C1**T
226 //
227                 auto W = Work(_(1,lastC),_(1,k));
228                 blas::copy(Trans, C1, W);
229 //
230 //              W := W * V1**T
231 //
232                 blas::mm(Right, Trans, One, V1.upperUnit(), W);
233 
234                 if (lastV>k) {
235 //
236 //                  W := W + C2**T*V2**T
237 //
238                     blas::mm(Trans, Trans, One, C2, V2, One, W);
239                 }
240 //
241 //              W := W * T**T  or  W * T
242 //
243                 blas::mm(Right, transT, One, Tr, W);
244 //
245 //              C := C - V**T * W**T
246 //
247                 if (lastV>k) {
248 //
249 //                  C2 := C2 - V2**T * W**T
250 //
251                     blas::mm(Trans, Trans, -One, V2, W, One, C2);
252                 }
253 //
254 //              W := W * V1
255 //
256                 blas::mm(Right, NoTrans, One, V1.upperUnit(), W);
257 //
258 //              C1 := C1 - W**T
259 //
260                 blas::axpy(Trans, -One, W, C1);
261             } else if (side==Right) {
262 //
263 //              Form  C * H  or  C * H**T  where  C = ( C1  C2 )
264 //
265                 const IndexType lastV = max(k, ilalc(V));
266                 const auto V1 = V(_,_(1,k));
267                 const auto V2 = V(_,_(k+1,lastV));
268 
269                 const IndexType lastC = ilalr(C(_,_(1,lastV)));
270                 auto C1 = C(_(1,lastC),_(1,k));
271                 auto C2 = C(_(1,lastC),_(k+1,lastV));
272 //
273 //              W := C * V**T  =  (C1*V1**T + C2*V2**T)  (stored in WORK)
274 //
275 //              W := C1
276 //
277                 auto W = Work(_(1,lastC),_(1,k));
278                 W = C1;
279 //
280 //              W := W * V1**T
281 //
282                 blas::mm(Right, Trans, One, V1.upperUnit(), W);
283 
284                 if (lastV>k) {
285 //
286 //                  W := W + C2 * V2**T
287 //
288                     blas::mm(NoTrans, Trans, One, C2, V2, One, W);
289                 }
290 //
291 //              W := W * T  or  W * T**T
292 //
293                 blas::mm(Right, transH, One, Tr, W);
294 //
295 //              C := C - W * V
296 //
297                 if (lastV>k) {
298 //
299 //                  C2 := C2 - W * V2
300 //
301                     blas::mm(NoTrans, NoTrans, -One, W, V2, One, C2);
302                 }
303 //
304 //              W := W * V1
305 //
306                 blas::mm(Right, NoTrans, One, V1.upperUnit(), W);
307 //
308 //              C1 := C1 - W
309 //
310                 blas::axpy(NoTrans, -One, W, C1);
311             }
312         } else if (direction==Backward) {
313             if (side==Left) {
314                 // Lehn: I will implement it as soon as someone needs it
315                 ASSERT(0);
316             } else if (side==Right) {
317                 // Lehn: I will implement it as soon as someone needs it
318                 ASSERT(0);
319             }
320         }
321     }
322 }
323 
324 //== interface for native lapack ===============================================
325 
326 #ifdef CHECK_CXXLAPACK
327 
328 template <typename MV, typename MT, typename MC, typename MWORK>
329 void
330 larfb_native(Side                   side,
331              Transpose              transH,
332              Direction              direction,
333              StoreVectors           storeVectors,
334              const GeMatrix<MV>     &V,
335              const TrMatrix<MT>     &Tr,
336              GeMatrix<MC>           &C,
337              GeMatrix<MWORK>        &Work)
338 {
339     typedef typename TrMatrix<MT>::ElementType  T;
340 
341     const char      SIDE    = char(side);
342     const char      TRANS   = getF77LapackChar(transH);
343     const char      DIRECT  = char(direction);
344     const char      STOREV  = char(storeVectors);
345     const INTEGER   M       = C.numRows();
346     const INTEGER   N       = C.numCols();
347     const INTEGER   K       = Tr.dim();
348     const INTEGER   LDV     = V.leadingDimension();
349     const INTEGER   LDT     = Tr.leadingDimension();
350     const INTEGER   LDC     = C.leadingDimension();
351     const INTEGER   LDWORK  = Work.leadingDimension();
352 
353     if (IsSame<T, DOUBLE>::value) {
354         LAPACK_IMPL(dlarfb)(&SIDE,
355                             &TRANS,
356                             &DIRECT,
357                             &STOREV,
358                             &M,
359                             &N,
360                             &K,
361                             V.data(),
362                             &LDV,
363                             Tr.data(),
364                             &LDT,
365                             C.data(),
366                             &LDC,
367                             Work.data(),
368                             &LDWORK);
369     } else {
370         ASSERT(0);
371     }
372 }
373 
374 #endif // CHECK_CXXLAPACK
375 
376 //== public interface ==========================================================
377 
378 template <typename MV, typename MT, typename MC, typename MWORK>
379 void
380 larfb(Side                  side,
381       Transpose             transH,
382       Direction             direction,
383       StoreVectors          storeV,
384       const GeMatrix<MV>    &V,
385       const TrMatrix<MT>    &Tr,
386       GeMatrix<MC>          &C,
387       GeMatrix<MWORK>       &Work)
388 {
389     LAPACK_DEBUG_OUT("larfb");
390 
391 //
392 //  Test the input parameters
393 //
394 #   ifndef NDEBUG
395     ASSERT(transH!=Conj);
396 
397     if (side==Left) {
398         ASSERT(Work.numRows()>=C.numCols());
399     } else {
400         ASSERT(Work.numRows()>=C.numRows());
401     }
402     ASSERT(Work.numCols()==Tr.dim());
403 #   endif
404 
405 #   ifdef CHECK_CXXLAPACK
406 //
407 //  Make copies of output arguments
408 //
409     typename GeMatrix<MC>::NoView       C_org = C;
410     typename GeMatrix<MWORK>::NoView    Work_org = Work;
411 #   endif
412 
413 //
414 //  Call implementation
415 //
416     larfb_generic(side, transH, direction, storeV, V, Tr, C, Work);
417 
418 #   ifdef CHECK_CXXLAPACK
419 //
420 //  Restore output arguments
421 //
422     typename GeMatrix<MC>::NoView       C_generic = C;
423     typename GeMatrix<MWORK>::NoView    Work_generic = Work;
424 
425     C    = C_org;
426     Work = Work_org;
427 //
428 //  Compare results
429 //
430     larfb_native(side, transH, direction, storeV, V, Tr, C, Work);
431 
432     bool failed = false;
433     if (! isIdentical(C_generic, C, "C_generic""C")) {
434         std::cerr << "CXXLAPACK: C_generic = " << C_generic << std::endl;
435         std::cerr << "F77LAPACK: C = " << C << std::endl;
436         failed = true;
437     }
438     if (! isIdentical(Work_generic, Work, " Work_generic""Work")) {
439         std::cerr << "CXXLAPACK: Work_generic = " << Work_generic << std::endl;
440         std::cerr << "F77LAPACK: Work = " << Work << std::endl;
441         failed = true;
442     }
443     if (failed) {
444         std::cerr << "side =      " << char(side) << std::endl;
445         std::cerr << "transH =    " << transH << std::endl;
446         std::cerr << "direction = " << char(direction) << std::endl;
447         std::cerr << "storeV =    " << char(storeV) << std::endl;
448         ASSERT(0);
449     }
450 #   endif
451 }
452 
453 //-- forwarding ----------------------------------------------------------------
454 template <typename MV, typename MT, typename MC, typename MWORK>
455 void
456 larfb(Side              side,
457       Transpose         transH,
458       Direction         direction,
459       StoreVectors      storeV,
460       const MV          &V,
461       const MT          &Tr,
462       MC                &&C,
463       MWORK             &&Work)
464 {
465     larfb(side, transH, direction, storeV, V, Tr, C, Work);
466 }
467 
468 } } // namespace lapack, flens
469 
470 #endif // FLENS_LAPACK_AUX_LARFB_TCC