1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35        SUBROUTINE DLAHR2( N, K, NB, A, LDA, TAU, T, LDT, Y, LDY )
 36  *
 37  *  -- LAPACK auxiliary routine (version 3.3.1)                        --
 38  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 39  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 40  *  -- April 2009                                                      --
 41  */
 42 
 43 #ifndef FLENS_LAPACK_EIG_LAHR2_TCC
 44 #define FLENS_LAPACK_EIG_LAHR2_TCC 1
 45 
 46 #include <flens/blas/blas.h>
 47 #include <flens/lapack/lapack.h>
 48 
 49 namespace flens { namespace lapack {
 50 
 51 //== generic lapack implementation =============================================
 52 
 53 template <typename IndexType, typename MA, typename VTAU,
 54           typename MTR, typename MY>
 55 void
 56 lahr2_generic(IndexType k, IndexType nb, GeMatrix<MA> &A,
 57               DenseVector<VTAU> &tau,   TrMatrix<MTR> &Tr,
 58               GeMatrix<MY> &Y)
 59 {
 60     using lapack::larfg;
 61     using std::min;
 62 
 63     typedef typename GeMatrix<MA>::ElementType  T;
 64 
 65     const Underscore<IndexType> _;
 66 
 67     const IndexType n = A.numRows();
 68     const T         Zero(0), One(1);
 69 
 70 //  TODO: as long as view creation is not supported by the TrMatrix interface
 71 //        we get them through a GeMatrix::View
 72     auto _Tr = Tr.general();
 73 //
 74 //  Quick return if possible
 75 //
 76     if (n<=1) {
 77         return;
 78     }
 79 
 80     T ei = T(0);
 81     for (IndexType i=1; i<=nb; ++i) {
 82         if (i>1) {
 83 //
 84 //          Update A(K+1:N,I)
 85 //
 86 //          Update I-th column of A - Y * V**T
 87 //
 88             blas::mv(NoTrans,
 89                      -One, Y(_(k+1,n),_(1,i-1)), A(k+i-1,_(1,i-1)),
 90                      One, A(_(k+1,n),i));
 91 //
 92 //          Apply I - V * T**T * V**T to this column (call it b) from the
 93 //          left, using the last column of T as workspace
 94 //
 95 //          Let  V = ( V1 )   and   b = ( b1 )   (first I-1 rows)
 96 //                   ( V2 )             ( b2 )
 97 //
 98 //          where V1 is unit lower triangular
 99 //
100 //          w := V1**T * b1
101 //
102             _Tr(_(1,i-1),nb) = A(_(k+1,k+i-1),i);
103             blas::mv(Trans, A(_(k+1,k+i-1),_(1,i-1)).lowerUnit(),
104                      _Tr(_(1,i-1),nb));
105 //
106 //          w := w + V2**T * b2
107 //
108             blas::mv(Trans,
109                      One, A(_(k+i,n),_(1,i-1)), A(_(k+i,n),i),
110                      One, _Tr(_(1,i-1),nb));
111 //
112 //          w := T**T * w
113 //
114             blas::mv(Trans, _Tr(_(1,i-1),_(1,i-1)).upper(), _Tr(_(1,i-1),nb));
115 //
116 //          b2 := b2 - V2*w
117 //
118             blas::mv(NoTrans,
119                      -One, A(_(k+i,n),_(1,i-1)), _Tr(_(1,i-1),nb),
120                      One, A(_(k+i,n),i));
121 //
122 //          b1 := b1 - V1*w
123 //
124             blas::mv(NoTrans,
125                      A(_(k+1,k+i-1),_(1,i-1)).lowerUnit(),
126                      _Tr(_(1,i-1),nb));
127             A(_(k+1,k+i-1),i) -= _Tr(_(1,i-1),nb);
128 
129             A(k+i-1,i-1) = ei;
130         }
131 //
132 //      Generate the elementary reflector H(I) to annihilate
133 //      A(K+I+1:N,I)
134 //
135         larfg(n-k-i+1, A(k+i,i), A(_(min(k+i+1,n),n), i), tau(i));
136 
137         ei = A(k+i, i);
138         A(k+i, i) = One;
139 //
140 //      Compute  Y(K+1:N,I)
141 //
142         blas::mv(NoTrans,
143                  One, A(_(k+1,n),_(i+1,n-k+1)), A(_(k+i,n),i),
144                  Zero, Y(_(k+1,n),i));
145         blas::mv(Trans,
146                  One, A(_(k+i,n),_(1,i-1)), A(_(k+i,n),i),
147                  Zero, _Tr(_(1,i-1),i));
148         blas::mv(NoTrans,
149                  -One, Y(_(k+1,n),_(1,i-1)), _Tr(_(1,i-1),i),
150                  One, Y(_(k+1,n),i));
151         blas::scal(tau(i),Y(_(k+1,n),i));
152 //
153 //      Compute T(1:I,I)
154 //
155         blas::scal(-tau(i), _Tr(_(1,i-1),i));
156         blas::mv(NoTrans, _Tr(_(1,i-1),_(1,i-1)).upper(), _Tr(_(1,i-1),i));
157         Tr(i,i) = tau(i);
158     }
159     A(k+nb, nb) = ei;
160 //
161 //  Compute Y(1:K,1:NB)
162 //
163     Y(_(1,k),_(1,nb)) = A(_(1,k),_(2,2+nb-1));
164     blas::mm(Right, NoTrans,
165              One, A(_(k+1,k+nb),_(1,nb)).lowerUnit(),
166              Y(_(1,k),_(1,nb)));
167 
168     if (n>k+nb) {
169         blas::mm(NoTrans, NoTrans,
170                  One, A(_(1,k),_(2+nb,n-k+1)), A(_(k+1+nb,n),_(1,nb)),
171                  One, Y(_(1,k),_(1,nb)));
172     }
173     blas::mm(Right, NoTrans, One, Tr, Y(_(1,k),_(1,nb)));
174 }
175 
176 //== interface for native lapack ===============================================
177 
178 #ifdef CHECK_CXXLAPACK
179 
180 template <typename IndexType, typename MA, typename VTAU,
181           typename MTR, typename MY>
182 void
183 lahr2_native(IndexType k, IndexType nb, GeMatrix<MA> &A,
184              DenseVector<VTAU> &tau,   TrMatrix<MTR> &Tr,
185              GeMatrix<MY> &Y)
186 {
187     typedef typename  GeMatrix<MY>::ElementType     T;
188 
189     const INTEGER N     = A.numRows();
190     const INTEGER K     = k;
191     const INTEGER NB    = nb;
192     const INTEGER LDA   = A.leadingDimension();
193     const INTEGER LDT   = Tr.leadingDimension();
194     const INTEGER LDY   = Y.leadingDimension();
195 
196     if (IsSame<T,DOUBLE>::value) {
197         LAPACK_IMPL(dlahr2)(&N,
198                             &K,
199                             &NB,
200                             A.data(),
201                             &LDA,
202                             tau.data(),
203                             Tr.data(),
204                             &LDT,
205                             Y.data(),
206                             &LDY);
207     } else {
208         ASSERT(0);
209     }
210 }
211 
212 #endif // CHECK_CXXLAPACK
213 
214 //== public interface ==========================================================
215 
216 template <typename IndexType, typename MA, typename VTAU,
217           typename MTR, typename MY>
218 void
219 lahr2(IndexType k, IndexType nb, GeMatrix<MA> &A, DenseVector<VTAU> &tau,
220       TrMatrix<MTR> &Tr, GeMatrix<MY> &Y)
221 {
222     LAPACK_DEBUG_OUT("lahr2");
223 
224 //
225 //  Test the input parameters
226 //
227     ASSERT(k<A.numRows());
228     ASSERT(A.firstRow()==1);
229     ASSERT(A.firstCol()==1);
230     ASSERT(A.numCols()==A.numRows()-k+1);
231     ASSERT(tau.length()==nb);
232     ASSERT(Tr.dim()==nb);
233     ASSERT(Y.numRows()==A.numRows());
234     ASSERT(Y.numCols()==nb);
235 
236 //
237 //  Make copies of output arguments
238 //
239 #   ifdef CHECK_CXXLAPACK
240     typename GeMatrix<MA>::NoView       _A      = A;
241     typename DenseVector<VTAU>::NoView  _tau    = tau;
242     typename GeMatrix<MTR>::NoView      __Tr    = Tr.general();
243     typename GeMatrix<MY>::NoView       _Y      = Y;
244 
245     auto _Tr = (Tr.upLo()==Upper) ? __Tr.upper() : __Tr.lower();
246 #   endif
247 
248 //
249 //  Call implementation
250 //
251     lahr2_generic(k, nb, A, tau, Tr, Y);
252 
253 //
254 //  Compare results
255 //
256 #   ifdef CHECK_CXXLAPACK
257     lahr2_native(k, nb, _A, _tau, _Tr, _Y);
258 
259     bool failed = false;
260     if (! isIdentical(A, _A, " A""A_")) {
261         std::cerr << "CXXLAPACK:  A = " << A << std::endl;
262         std::cerr << "F77LAPACK: _A = " << _A << std::endl;
263         failed = true;
264     }
265 
266     if (! isIdentical(tau, _tau, " tau""tau_")) {
267         std::cerr << "CXXLAPACK:  tau = " << tau << std::endl;
268         std::cerr << "F77LAPACK: _tau = " << _tau << std::endl;
269         failed = true;
270     }
271 
272     if (! isIdentical(Tr.general(), _Tr.general(), " Tr""_Tr")) {
273         std::cerr << "CXXLAPACK:  Tr = " << Tr.general() << std::endl;
274         std::cerr << "F77LAPACK: _Tr = " << _Tr.general() << std::endl;
275         failed = true;
276     }
277 
278     if (! isIdentical(Y, _Y, " Y""_Y")) {
279         std::cerr << "CXXLAPACK:  Y = " << Y << std::endl;
280         std::cerr << "F77LAPACK: _Y = " << _Y << std::endl;
281         failed = true;
282     }
283 
284     if (failed) {
285         std::cerr << "error in: lahr2.tcc" << std::endl;
286         std::cerr << "k = " << k << ", nb = " << nb << std::endl;
287         ASSERT(0);
288     } else {
289 //        std::cerr << "passed: lahr2.tcc" << std::endl;
290     }
291 #   endif
292 }
293 
294 //-- forwarding ----------------------------------------------------------------
//  Overload taking forwarding references, so that the matrix/vector arguments
//  may also be temporaries (presumably views created in the call expression —
//  confirm against callers).  Inside the body the named parameters are
//  lvalues, so the call below is intended to resolve to the lvalue-reference
//  lahr2 overload.  It must not use std::forward: an rvalue would not bind to
//  that overload and would select this one again.
//  NOTE(review): CHECKPOINT_ENTER / CHECKPOINT_LEAVE are macros defined
//  elsewhere; their effect is not visible here.
template <typename IndexType, typename MA, typename VTAU,
          typename MTR, typename MY>
void
lahr2(IndexType k, IndexType nb, MA &&A, VTAU &&tau, MTR &&Tr, MY &&Y)
{
    CHECKPOINT_ENTER;
    lahr2(k, nb, A, tau, Tr, Y);
    CHECKPOINT_LEAVE;
}
304 
305 } } // namespace lapack, flens
306 
307 #endif // FLENS_LAPACK_EIG_LAHR2_TCC