1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35        SUBROUTINE DGELS( TRANS, M, N, NRHS, A, LDA, B, LDB, WORK, LWORK,
 36       $                  INFO )
 37  *
 38  *  -- LAPACK driver routine (version 3.3.1) --
 39  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 40  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 41  *  -- April 2011                                                      --
 42  */
 43 
 44 #ifndef FLENS_LAPACK_LS_LS_TCC
 45 #define FLENS_LAPACK_LS_LS_TCC 1
 46 
 47 #include <flens/blas/blas.h>
 48 #include <flens/lapack/lapack.h>
 49 
 50 namespace flens { namespace lapack {
 51 
 52 //== generic lapack implementation =============================================
 53 
 54 template <typename MA, typename MB, typename VWORK>
 55 typename GeMatrix<MA>::IndexType
 56 ls_generic(Transpose                 trans,
 57            GeMatrix<MA>              &A,
 58            GeMatrix<MB>              &B,
 59            DenseVector<VWORK>        &work)
 60 {
 61     using std::max;
 62     using flens::min;
 63 
 64     typedef typename GeMatrix<MB>::ElementType  ElementType;
 65     typedef typename GeMatrix<MB>::IndexType    IndexType;
 66 
 67     const ElementType  Zero(0), One(1);
 68 
 69     const Underscore<IndexType>  _;
 70 
 71     const IndexType  m = A.numRows();
 72     const IndexType  n = A.numCols();
 73     const IndexType  nRhs = B.numCols();
 74     const IndexType  mn = min(m,n);
 75 
 76     IndexType info = 0;
 77 //
 78 //  Figure out optimal block size
 79 //
 80     IndexType  nb;
 81     bool       tpsd = (trans==NoTrans) ? false : true;
 82 
 83     if (m>=n) {
 84         nb = ilaenv<ElementType>(1"GEQRF""", m, n);
 85         if (tpsd) {
 86             nb = max(nb, ilaenv<ElementType>(1"ORMQR""LN", m, nRhs, n));
 87         } else {
 88             nb = max(nb, ilaenv<ElementType>(1"ORMQR""LT", m, nRhs, n));
 89         }
 90     } else {
 91         nb = ilaenv<ElementType>(1"GELQF""", m, n);
 92         if (tpsd) {
 93             nb = max(nb, ilaenv<ElementType>(1"ORMLQ""LT", n, nRhs, m));
 94         } else {
 95             nb = max(nb, ilaenv<ElementType>(1"ORMLQ""LN", n, nRhs, m));
 96         }
 97     }
 98 
 99     IndexType  wSize = max(IndexType(1), mn+max(mn,nRhs)*nb);
100     if (work.length()==0) {
101         work.resize(wSize);
102     }
103     work(1) = ElementType(wSize);
104     const IndexType  lWork = work.length();
105 //
106 //  Quick return if possible
107 //
108     if (min(m, n, nRhs)==IndexType(0)) {
109         B = Zero;
110         return info;
111     }
112 //
113 //  Get machine parameters
114 //
115     ElementType  smallNum = lamch<ElementType>(SafeMin)
116                           / lamch<ElementType>(Precision);
117     ElementType  bigNum = One / smallNum;
118     labad(smallNum, bigNum);
119 //
120 //  Scale A, B if max element outside range [SMLNUM,BIGNUM]
121 //
122     const ElementType  normA = lan(MaximumNorm, A);
123     IndexType iScaleA = 0;
124 
125     if (normA>Zero && normA<smallNum) {
126 //
127 //      Scale matrix norm up to SMLNUM
128 //
129         lascl(LASCL::FullMatrix, 00, normA, smallNum, A);
130         iScaleA = 1;
131     } else if (normA>bigNum) {
132 //
133 //      Scale matrix norm down to BIGNUM
134 //
135         lascl(LASCL::FullMatrix, 00, normA, bigNum, A);
136         iScaleA = 2;
137     } else if (normA==Zero) {
138 //
139 //      Matrix all zero. Return zero solution.
140 //
141         lascl(LASCL::FullMatrix, max(m,n), nRhs, Zero, Zero, A);
142         work(1) = wSize;
143     }
144 
145     auto _B = (tpsd) ? B(_(1,n),_) : B(_(1,m),_);
146     const ElementType  normB = lan(MaximumNorm, _B);
147     IndexType iScaleB = 0;
148     IndexType scaleLen = 0;
149 
150     if (normB>Zero && normB<smallNum) {
151 //
152 //      Scale matrix norm up to SMLNUM
153 //
154         lascl(LASCL::FullMatrix, 00, normB, smallNum, B);
155         iScaleB = 1;
156     } else if (normB>bigNum) {
157 //
158 //      Scale matrix norm down to BIGNUM
159 //
160         lascl(LASCL::FullMatrix, 00, normB, bigNum, B);
161         iScaleB = 2;
162     }
163 
164     auto tau   = work(_(1,mn));
165     auto _work = work(_(mn+1,lWork));
166     if (m>=n) {
167 //
168 //      compute QR factorization of A
169 //
170         qrf(A, tau, _work);
171         const auto R = A(_(1,n),_(1,n)).upper();
172 //
173 //      workspace at least N, optimally N*NB
174 //
175         if (!tpsd) {
176 //
177 //          Least-Squares Problem min || A * X - B ||
178 //
179 //          B(1:M,1:NRHS) := Q**T * B(1:M,1:NRHS)
180 //
181             ormqr(Left, Trans, A, tau, B, _work);
182 //
183 //          workspace at least NRHS, optimally NRHS*NB
184 //
185 //          B(1:N,1:NRHS) := inv(R) * B(1:N,1:NRHS)
186 //
187             info = trs(NoTrans, R, B(_(1,n),_));
188 
189             if (info>0) {
190                 return info;
191             }
192             scaleLen = n;
193 
194         } else {
195 //
196 //          Overdetermined system of equations A**T * X = B
197 //
198 //          B(1:N,1:NRHS) := inv(R**T) * B(1:N,1:NRHS)
199 //
200             info = trs(Trans, R, B(_(1,n),_));
201 
202             if (info>0) {
203                 return info;
204             }
205 //
206 //          B(N+1:M,1:NRHS) = ZERO
207 //
208             B(_(n+1,m),_) = Zero;
209 //
210 //          B(1:M,1:NRHS) := Q(1:N,:) * B(1:N,1:NRHS)
211 //
212             ormqr(Left, NoTrans, A, tau, B, _work);
213 //
214 //          workspace at least NRHS, optimally NRHS*NB
215 //
216             scaleLen = m;
217 
218         }
219 
220     } else {
221 //
222 //      Compute LQ factorization of A
223 //
224         lqf(A, tau, _work);
225         const auto L = A(_(1,m),_(1,m)).lower();
226 //
227 //       workspace at least M, optimally M*NB.
228 //
229         if (!tpsd) {
230 //
231 //          underdetermined system of equations A * X = B
232 //
233 //          B(1:M,1:NRHS) := inv(L) * B(1:M,1:NRHS)
234 //
235             info = trs(NoTrans, L, B(_(1,m),_));
236 
237             if (info>0) {
238                 return info;
239             }
240 //
241 //          B(M+1:N,1:NRHS) = 0
242 //
243             B(_(m+1,n),_) = Zero;
244 //
245 //          B(1:N,1:NRHS) := Q(1:N,:)**T * B(1:M,1:NRHS)
246 //
247             ormlq(Left, Trans, A, tau, B, _work);
248 //
249 //          workspace at least NRHS, optimally NRHS*NB
250 //
251             scaleLen = n;
252 //
253         } else {
254 //
255 //          overdetermined system min || A**T * X - B ||
256 //
257 //          B(1:N,1:NRHS) := Q * B(1:N,1:NRHS)
258 //
259             ormlq(Left, NoTrans, A, tau, B, _work);
260 //
261 //          workspace at least NRHS, optimally NRHS*NB
262 //
263 //          B(1:M,1:NRHS) := inv(L**T) * B(1:M,1:NRHS)
264 //
265             info = trs(Trans, L, B(_(1,m),_));
266 
267             if (info>0) {
268                 return info;
269             }
270 
271             scaleLen = m;
272 
273         }
274 
275     }
276 //
277 //  Undo scaling
278 //
279     auto __B = B(_(1,scaleLen),_);
280     if (iScaleA==1) {
281         lascl(LASCL::FullMatrix, 00, normA, smallNum, __B);
282     } else if (iScaleA==2) {
283         lascl(LASCL::FullMatrix, 00, normA, bigNum, __B);
284     }
285     if (iScaleB==1) {
286         lascl(LASCL::FullMatrix, 00, smallNum, normB, __B);
287     } else if (iScaleB==2) {
288         lascl(LASCL::FullMatrix, 00, bigNum, normB, __B);
289     }
290 
291     work(1) = ElementType(wSize);
292     return info;
293 }
294 
295 //== interface for native lapack ===============================================
296 
297 #ifdef CHECK_CXXLAPACK
298 
299 template <typename MA, typename MB, typename VWORK>
300 typename GeMatrix<MA>::IndexType
301 ls_native(Transpose                 trans,
302           GeMatrix<MA>              &A,
303           GeMatrix<MB>              &B,
304           DenseVector<VWORK>        &work)
305 {
306     typedef typename GeMatrix<MA>::ElementType  T;
307 
308     const char      TRANS = getF77LapackChar(trans);
309     const INTEGER   M = A.numRows();
310     const INTEGER   N = A.numCols();
311     const INTEGER   NRHS = B.numCols();
312     const INTEGER   LDA = A.leadingDimension();
313     const INTEGER   LDB = B.leadingDimension();
314     const INTEGER   LWORK = work.length();
315     INTEGER         INFO;
316 
317     if (IsSame<T,DOUBLE>::value) {
318         LAPACK_IMPL(dgels)(&TRANS,
319                            &M,
320                            &N,
321                            &NRHS,
322                            A.data(),
323                            &LDA,
324                            B.data(),
325                            &LDB,
326                            work.data(),
327                            &LWORK,
328                            &INFO);
329     } else {
330         ASSERT(0);
331     }
332 
333     ASSERT(INFO>=0);
334     return INFO;
335 }
336 
337 #endif // CHECK_CXXLAPACK
338 
339 //== public interface ==========================================================
340 
341 template <typename MA, typename MB, typename VWORK>
342 typename GeMatrix<MA>::IndexType
343 ls(Transpose                 trans,
344    GeMatrix<MA>              &A,
345    GeMatrix<MB>              &B,
346    DenseVector<VWORK>        &work)
347 {
348     using std::max;
349     using std::min;
350 //
351 //  Test the input parameters
352 //
353 #   ifndef NDEBUG
354     typedef typename GeMatrix<MA>::IndexType    IndexType;
355 
356     const IndexType m = A.numRows();
357     const IndexType n = A.numCols();
358     const IndexType nRhs = B.numCols();
359 
360     ASSERT(B.numRows()==max(m,n));
361 
362     if (work.length()>0) {
363         const IndexType mn = min(m, n);
364         ASSERT(work.length()>=max(IndexType(1),mn+max(mn,nRhs)));
365     }
366 #   endif
367 
368 //
369 //  Make copies of output arguments
370 //
371 #   ifdef CHECK_CXXLAPACK
372     typename GeMatrix<MA>::NoView       A_org      = A;
373     typename GeMatrix<MB>::NoView       B_org      = B;
374     typename DenseVector<VWORK>::NoView work_org   = work;
375 #   endif
376 
377 //
378 //  Call implementation
379 //
380     IndexType info = ls_generic(trans, A, B, work);
381 
382 #   ifdef CHECK_CXXLAPACK
383 //
384 //  Make copies of results computed by the generic implementation
385 //
386     typename GeMatrix<MA>::NoView       A_generic       = A;
387     typename GeMatrix<MB>::NoView       B_generic       = B;
388     typename DenseVector<VWORK>::NoView work_generic    = work;
389 
390 //
391 //  restore output arguments
392 //
393     A = A_org;
394     B = B_org;
395     work = work_org;
396 
397 //
398 //  Compare generic results with results from the native implementation
399 //
400     IndexType _info = ls_native(trans, A, B, work);
401 
402     bool failed = false;
403     if (! isIdentical(A_generic, A, "A_generic""A")) {
404         std::cerr << "CXXLAPACK: A_generic = " << A_generic << std::endl;
405         std::cerr << "F77LAPACK: A = " << A << std::endl;
406         failed = true;
407     }
408     if (! isIdentical(B_generic, B, "B_generic""B")) {
409         std::cerr << "CXXLAPACK: B_generic = " << B_generic << std::endl;
410         std::cerr << "F77LAPACK: B = " << B << std::endl;
411         failed = true;
412     }
413     if (! isIdentical(work_generic, work, "work_generic""work")) {
414         std::cerr << "CXXLAPACK: work_generic = " << work_generic << std::endl;
415         std::cerr << "F77LAPACK: work = " << work << std::endl;
416         failed = true;
417     }
418     if (! isIdentical(info, _info, "info""_info")) {
419         std::cerr << "CXXLAPACK: info = " << info << std::endl;
420         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
421         failed = true;
422     }
423 
424     if (failed) {
425         std::cerr << "error in: ls.tcc" << std::endl;
426         ASSERT(0);
427     } else {
428         // std::cerr << "passed: ls.tcc" << std::endl;
429     }
430 #   endif
431 
432     return info;
433 }
434 
435 //-- forwarding ----------------------------------------------------------------
436 template <typename MA, typename MB, typename VWORK>
437 typename MA::IndexType
438 ls(Transpose               trans,
439    MA                      &&A,
440    MB                      &&B,
441    VWORK                   &&work)
442 {
443     typename MA::IndexType info;
444 
445     CHECKPOINT_ENTER;
446     info = ls(trans, A, B, work);
447     CHECKPOINT_LEAVE;
448 
449     return info;
450 }
451 
452 } } // namespace lapack, flens
453 
454 #endif // FLENS_LAPACK_LS_LS_TCC