1 /*
  2  *   Copyright (c) 2011, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 /* Based on
 34  *
 35        SUBROUTINE DPOTRS( UPLO, N, NRHS, A, LDA, B, LDB, INFO )
 36  *
 37  *  -- LAPACK routine (version 3.3.1) --
 38  *  -- LAPACK is a software package provided by Univ. of Tennessee,    --
 39  *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 40  *  -- April 2011                                                      --
 41  */
 42 
 43 #ifndef FLENS_LAPACK_GESV_POTRS_TCC
 44 #define FLENS_LAPACK_GESV_POTRS_TCC 1
 45 
 46 #include <algorithm>
 47 #include <flens/blas/blas.h>
 48 #include <flens/lapack/lapack.h>
 49 
 50 #include <flens/lapack/interface/include/f77lapack.h>
 51 
 52 namespace flens { namespace lapack {
 53 
 54 //== generic lapack implementation =============================================
 55 
 56 template <typename MA, typename MB>
 57 void
 58 potrs_generic(const SyMatrix<MA> &A, GeMatrix<MB> &B)
 59 {
 60     using std::isnan;
 61     using std::sqrt;
 62 
 63     typedef typename SyMatrix<MA>::ElementType  ElementType;
 64     typedef typename SyMatrix<MA>::IndexType    IndexType;
 65 
 66 
 67     const IndexType n    = A.dim();
 68     const IndexType nRhs = B.numCols();
 69     const bool upper     = (A.upLo()==Upper);
 70 
 71     const ElementType  One(1);
 72 //
 73 //  Quick return if possible
 74 //
 75     if (n==0 || nRhs==0) {
 76         return;
 77     }
 78     if (upper) {
 79 //
 80 //      Solve A*X = B where A = U**T *U.
 81 //
 82 //      Solve U**T *X = B, overwriting B with X.
 83 //
 84         blas::sm(Left, Trans, One, A.triangular(), B);
 85 //
 86 //      Solve U*X = B, overwriting B with X.
 87 //
 88         blas::sm(Left, NoTrans, One, A.triangular(), B);
 89     } else {
 90 //
 91 //      Solve A*X = B where A = L*L**T.
 92 //
 93 //      Solve L*X = B, overwriting B with X.
 94 //
 95         blas::sm(Left, NoTrans, One, A.triangular(), B);
 96 //
 97 //      Solve L**T *X = B, overwriting B with X.
 98 //
 99         blas::sm(Left, Trans, One, A.triangular(), B);
100     }
101 }
102 
103 //== interface for native lapack ===============================================
104 
105 #ifdef CHECK_CXXLAPACK
106 
107 template <typename MA, typename MB>
108 void
109 potrs_native(const SyMatrix<MA> &A, GeMatrix<MB> &B)
110 {
111     typedef typename SyMatrix<MA>::ElementType  T;
112 
113     const char       UPLO = char(A.upLo());
114     const INTEGER    N    = A.dim();
115     const INTEGER    NRHS = B.numCols();
116     const INTEGER    LDA  = A.leadingDimension();
117     const INTEGER    LDB  = B.leadingDimension();
118     INTEGER          INFO;
119 
120     if (IsSame<T, DOUBLE>::value) {
121         LAPACK_IMPL(dpotrs)(&UPLO, &N, &NRHS,
122                             A.data(), &LDA,
123                             B.data(), &LDB,
124                             &INFO);
125     } else {
126         ASSERT(0);
127     }
128 }
129 
130 #endif // CHECK_CXXLAPACK
131 
132 //== public interface ==========================================================
133 
134 template <typename MA, typename MB>
135 void
136 potrs(const SyMatrix<MA> &A, GeMatrix<MB> &B)
137 {
138     typedef typename SyMatrix<MA>::IndexType    IndexType;
139 
140 //
141 //  Test the input parameters
142 //
143     ASSERT(A.firstRow()==1);
144     ASSERT(A.firstCol()==1);
145 
146     ASSERT(B.firstRow()==1);
147     ASSERT(B.firstCol()==1);
148 
149     ASSERT(A.dim()==B.numRows());
150 
151 #   ifdef CHECK_CXXLAPACK
152 //
153 //  Make copies of output arguments
154 //
155     typename GeMatrix<MB>::NoView B_org = B;
156 #   endif
157 
158 //
159 //  Call implementation
160 //
161     potrs_generic(A, B);
162 
163 #   ifdef CHECK_CXXLAPACK
164 //
165 //  Make copies of generic results
166 //
167     typename GeMatrix<MB>::NoView B_generic = B;
168 //
169 //  Restore output arguments
170 //
171     B = B_org;
172 
173 //
174 //  Compare results
175 //
176     potrs_native(A, B);
177 
178     bool failed = false;
179     if (! isIdentical(B_generic, B, "B_generic""B")) {
180         std::cerr << "CXXLAPACK: B_generic = " << B_generic << std::endl;
181         std::cerr << "F77LAPACK: B = " << B << std::endl;
182         failed = true;
183     }
184 
185     if (failed) {
186         ASSERT(0);
187     }
188 #   endif
189 }
190 
191 //-- forwarding ----------------------------------------------------------------
192 template <typename MA, typename MB>
193 void
194 potrs(const MA &A, MB &&B)
195 {
196     CHECKPOINT_ENTER;
197     potrs(A, B);
198     CHECKPOINT_LEAVE;
199 }
200 
201 } } // namespace lapack, flens
202 
203 #endif // FLENS_LAPACK_GESV_POTRS_TCC