1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
/* Based on
 34  *
 35        SUBROUTINE DGETRS( TRANS, N, NRHS, A, LDA, IPIV, B, LDB, INFO )
 36  *
 37  *  -- LAPACK routine (version 3.3.1) --
 38  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 39  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 40  *  -- April 2011                                                      --
 41  */
 42 
 43 #ifndef FLENS_LAPACK_GESV_TRS_TCC
 44 #define FLENS_LAPACK_GESV_TRS_TCC 1
 45 
 46 #include <flens/blas/blas.h>
 47 #include <flens/lapack/lapack.h>
 48 
 49 namespace flens { namespace lapack {
 50 
 51 //== generic lapack implementation =============================================
 52 // getrs
 53 template <typename MA, typename VP, typename MB>
 54 void
 55 trs_generic(Transpose trans, const GeMatrix<MA> &A, const DenseVector<VP> &piv,
 56             GeMatrix<MB> &B)
 57 {
 58     typedef typename GeMatrix<MA>::IndexType    IndexType;
 59     typedef typename GeMatrix<MA>::ElementType  T;
 60 
 61     const IndexType n       = A.numCols();
 62     const IndexType nRhs    = B.numCols();
 63 
 64     const T  One(1);
 65 //
 66 //  Quick return if possible
 67 //
 68     if ((n==0) || (nRhs==0)) {
 69         return;
 70     }
 71 
 72     if ((trans==NoTrans) || (trans==Conj)) {
 73 //
 74 //      Solve A * X = B.
 75 //
 76 //      Apply row interchanges to the right hand sides.
 77 //
 78         laswp(B, piv);
 79 //
 80 //      Solve L*X = B, overwriting B with X.
 81 //
 82         blas::sm(Left, trans, One, A.lowerUnit(), B);
 83 //
 84 //      Solve U*X = B, overwriting B with X.
 85 //
 86         blas::sm(Left, trans, One, A.upper(), B);
 87     } else {
 88 //
 89 //      Solve A' * X = B.
 90 //
 91 //      Solve U'*X = B, overwriting B with X.
 92 //
 93         blas::sm(Left, trans, One, A.upper(), B);
 94 //
 95 //      Solve L'*X = B, overwriting B with X.
 96 //
 97         blas::sm(Left, trans, One, A.lowerUnit(), B);
 98 //
 99 //      Apply row interchanges to the solution vectors.
100 //
101         laswp(B, piv.reverse());
102     }
103 }
104 
105 // trtrs
106 template <typename MA, typename MB>
107 typename TrMatrix<MA>::IndexType
108 trs_generic(Transpose trans, const TrMatrix<MA> &A, GeMatrix<MB> &B)
109 {
110     typedef typename TrMatrix<MA>::IndexType    IndexType;
111     typedef typename TrMatrix<MA>::ElementType  T;
112 
113     const IndexType n       = A.dim();
114 
115     const T  Zero(0), One(1);
116 
117     IndexType info = 0;
118 //
119 //  Quick return if possible
120 //
121     if (n==0) {
122         return info;
123     }
124 //
125 //  Check for singularity.
126 //
127     if (A.diag()!=Unit) {
128         for (info=1; info<=n; ++info) {
129             if (A(info,info)==Zero) {
130                 return info;
131             }
132         }
133     }
134     info = 0;
135 //
136 //  Solve A * x = b  or  A**T * x = b.
137 //
138     blas::sm(Left, trans, One, A, B);
139 
140     return info;
141 }
142 
143 //== interface for native lapack ===============================================
144 
145 #ifdef CHECK_CXXLAPACK
146 
147 // getrs
148 template <typename MA, typename VP, typename MB>
149 void
150 trs_native(Transpose trans, const GeMatrix<MA> &A, const DenseVector<VP> &piv,
151            GeMatrix<MB> &B)
152 {
153     typedef typename GeMatrix<MA>::ElementType ElementType;
154 
155     const char       TRANS = getF77LapackChar(trans);
156     const INTEGER    N     = A.numRows();
157     const INTEGER    NRHS  = B.numCols();
158     const INTEGER    LDA   = A.leadingDimension();
159     const INTEGER    LDB   = B.leadingDimension();
160     INTEGER          INFO;
161 
162 
163     if (IsSame<ElementType, double>::value) {
164         LAPACK_IMPL(dgetrs)(&TRANS,
165                             &N,
166                             &NRHS,
167                             A.data(),
168                             &LDA,
169                             piv.data(),
170                             B.data(),
171                             &LDB,
172                             &INFO);
173     } else {
174         ASSERT(0);
175     }
176     ASSERT(INFO==0);
177 }
178 
179 // trtrs
180 template <typename MA, typename MB>
181 typename TrMatrix<MA>::IndexType
182 trs_native(Transpose trans, const TrMatrix<MA> &A, GeMatrix<MB> &B)
183 {
184     typedef typename TrMatrix<MA>::ElementType ElementType;
185 
186     const char       UPLO = char(A.upLo());
187     const char       TRANS = getF77LapackChar(trans);
188     const char       DIAG  = char(A.diag());
189     const INTEGER    N     = A.dim();
190     const INTEGER    NRHS  = B.numCols();
191     const INTEGER    LDA   = A.leadingDimension();
192     const INTEGER    LDB   = B.leadingDimension();
193     INTEGER          INFO;
194 
195 
196     if (IsSame<ElementType, double>::value) {
197         LAPACK_IMPL(dtrtrs)(&UPLO,
198                             &TRANS,
199                             &DIAG,
200                             &N,
201                             &NRHS,
202                             A.data(),
203                             &LDA,
204                             B.data(),
205                             &LDB,
206                             &INFO);
207     } else {
208         ASSERT(0);
209     }
210     ASSERT(INFO>=0);
211     return INFO;
212 }
213 
214 #endif // CHECK_CXXLAPACK
215 
216 //== public interface ==========================================================
217 
218 // getrs
219 template <typename MA, typename VP, typename MB>
220 void
221 trs(Transpose trans, const GeMatrix<MA> &A, const DenseVector<VP> &piv,
222     GeMatrix<MB> &B)
223 {
224     typedef typename GeMatrix<MA>::IndexType  IndexType;
225 //
226 //  Test the input parameters
227 //
228 #   ifndef NDEBUG
229     ASSERT(A.firstRow()==1);
230     ASSERT(A.firstCol()==1);
231     ASSERT(A.numRows()==A.numCols());
232 
233     const IndexType n = A.numRows();
234 
235     ASSERT(piv.firstIndex()==1);
236     ASSERT(piv.length()==n);
237 
238     ASSERT(B.firstRow()==1);
239     ASSERT(B.firstCol()==1);
240     ASSERT(B.numRows()==n);
241 #   endif
242 
243 //
244 //  Make copies of output arguments
245 //
246     typename GeMatrix<MB>::NoView  B_org   = B;
247 //
248 //  Call implementation
249 //
250     trs_generic(trans, A, piv, B);
251 //
252 //  Compare results
253 //
254 #   ifdef CHECK_CXXLAPACK
255     typename GeMatrix<MB>::NoView  B_generic   = B;
256 
257     B   = B_org;
258 
259     trs_native(trans, A, piv, B);
260 
261     bool failed = false;
262     if (! isIdentical(B_generic, B, "B_generic""B")) {
263         std::cerr << "CXXLAPACK: B_generic = " << B_generic << std::endl;
264         std::cerr << "F77LAPACK: B = " << B << std::endl;
265         failed = true;
266     }
267 
268     if (failed) {
269         ASSERT(0);
270     } else {
271         // std::cerr << "passed: (ge)trs.tcc" << std::endl;
272     }
273 
274 #   endif
275 }
276 
277 template <typename MA, typename VP, typename VB>
278 void
279 trs(Transpose trans, const GeMatrix<MA> &A, const DenseVector<VP> &piv,
280     DenseVector<VB> &b)
281 {
282     typedef typename DenseVector<VB>::ElementType  ElementType;
283     typedef typename DenseVector<VB>::IndexType    IndexType;
284 
285     const IndexType    n     = b.length();
286     const StorageOrder order = GeMatrix<MA>::Engine::order;
287 
288     GeMatrix<FullStorageView<ElementType, order> >  B(n, 1, b, n);
289 
290     return trs(trans, A, piv, B);
291 }
292 
293 // trtrs
294 template <typename MA, typename MB>
295 typename TrMatrix<MA>::IndexType
296 trs(Transpose trans, const TrMatrix<MA> &A, GeMatrix<MB> &B)
297 {
298     typedef typename TrMatrix<MA>::IndexType  IndexType;
299 //
300 //  Test the input parameters
301 //
302 #   ifndef NDEBUG
303     ASSERT(A.firstRow()==1);
304     ASSERT(A.firstCol()==1);
305 
306     const IndexType n = A.dim();
307 
308     ASSERT(B.firstRow()==1);
309     ASSERT(B.firstCol()==1);
310     ASSERT(B.numRows()==n);
311 #   endif
312 
313 //
314 //  Make copies of output arguments
315 //
316     typename GeMatrix<MB>::NoView  B_org   = B;
317 //
318 //  Call implementation
319 //
320     IndexType info = trs_generic(trans, A, B);
321 //
322 //  Compare results
323 //
324 #   ifdef CHECK_CXXLAPACK
325     typename GeMatrix<MB>::NoView  B_generic   = B;
326 
327     B   = B_org;
328 
329     IndexType _info = trs_native(trans, A, B);
330 
331     bool failed = false;
332     if (! isIdentical(B_generic, B, "B_generic""B")) {
333         std::cerr << "CXXLAPACK: B_generic = " << B_generic << std::endl;
334         std::cerr << "F77LAPACK: B = " << B << std::endl;
335         failed = true;
336     }
337 
338     if (! isIdentical(info, _info, "info""_info")) {
339         std::cerr << "CXXLAPACK: info = " << info << std::endl;
340         std::cerr << "F77LAPACK: _info = " << _info << std::endl;
341         failed = true;
342     }
343 
344     if (failed) {
345         ASSERT(0);
346     } else {
347         // std::cerr << "passed: (tr)trs.tcc" << std::endl;
348     }
349 #   endif
350 
351     return info;
352 }
353 
354 template <typename MA, typename VB>
355 typename TrMatrix<MA>::IndexType
356 trs(Transpose trans, const TrMatrix<MA> &A, DenseVector<VB> &b)
357 {
358     typedef typename DenseVector<VB>::ElementType  ElementType;
359     typedef typename DenseVector<VB>::IndexType    IndexType;
360 
361     const IndexType    n     = b.length();
362     const StorageOrder order = TrMatrix<MA>::Engine::order;
363 
364     GeMatrix<FullStorageView<ElementType, order> >  B(n, 1, b, n);
365 
366     return trs(trans, A, B);
367 }
368 
369 //-- forwarding ----------------------------------------------------------------
//
//  Forwarding variant of trs (getrs).
//
//  B is taken as a forwarding reference (MB &&), presumably so that
//  temporary objects such as matrix views can be passed as right hand side.
//  Inside the function B is a named object, so the nested call receives it
//  as an lvalue.
//
//  The nested call is bracketed by the CHECKPOINT_ENTER / CHECKPOINT_LEAVE
//  macros, which are defined elsewhere.
//
//  NOTE(review): if no more specialized overload accepts the argument
//  types, the nested call looks like it would select this template again --
//  confirm callers only pass supported matrix/vector types.
//
template <typename MA, typename VP, typename MB>
void
trs(Transpose trans, const MA &A, const VP &piv, MB &&B)
{
    CHECKPOINT_ENTER;
    trs(trans, A, piv, B);
    CHECKPOINT_LEAVE;
}
378 
//
//  Forwarding variant of trs (trtrs).
//
//  B is taken as a forwarding reference (MB &&), presumably so that
//  temporary objects such as matrix views can be passed as right hand side.
//  Inside the function B is a named object, so the nested call receives it
//  as an lvalue.
//
//  Returns the value of the nested trs call.
//
//  NOTE(review): if no more specialized overload accepts the argument
//  types, the nested call looks like it would select this template again --
//  confirm callers only pass supported matrix/vector types.
//
template <typename MA, typename MB>
typename MA::IndexType
trs(Transpose trans, const MA &A, MB &&B)
{
    // Declared before CHECKPOINT_ENTER and returned after CHECKPOINT_LEAVE;
    // the macros are defined elsewhere, so this placement is kept as is.
    typename MA::IndexType info;

    CHECKPOINT_ENTER;
    info = trs(trans, A, B);
    CHECKPOINT_LEAVE;

    return info;
}
391 
392 } } // namespace lapack, flens
393 
394 #endif // FLENS_LAPACK_GESV_TRS_TCC