1 /*
  2  *   Copyright (c) 2009, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 #ifndef CXXBLAS_LEVEL3_TRMM_TCC
 34 #define CXXBLAS_LEVEL3_TRMM_TCC 1
 35 
 36 namespace cxxblas {
 37 
 38 template <typename IndexType, typename ALPHA, typename MA, typename MB>
 39 void
 40 trmm_generic(StorageOrder order, Side sideA, StorageUpLo upLoA,
 41              Transpose transA, Diag diagA,
 42              IndexType m, IndexType n,
 43              const ALPHA &alpha,
 44              const MA *A, IndexType ldA,
 45              MB *B, IndexType ldB)
 46 {
 47     if (order==ColMajor) {
 48         sideA = (sideA==Left) ? Right : Left;
 49         upLoA = (upLoA==Upper) ? Lower : Upper;
 50         trmm_generic(RowMajor, sideA, upLoA, transA, diagA, n, m,
 51                      alpha, A, ldA, B, ldB);
 52         return;
 53     }
 54     if (sideA==Right) {
 55         transA = Transpose(transA^Trans);
 56         for (IndexType i=0; i<m; ++i) {
 57             trmv(order, upLoA, transA, diagA, n, A, ldA, B+i*ldB, IndexType(1));
 58         }
 59     }
 60     if (sideA==Left) {
 61         for (IndexType j=0; j<n; ++j) {
 62             trmv(order, upLoA, transA, diagA, m, A, ldA, B+j, ldB);
 63         }
 64     }
 65     gescal(order, m, n, alpha, B, ldB);    
 66 }
 67 
 68 template <typename IndexType, typename ALPHA, typename MA, typename MB>
 69 void
 70 trmm(StorageOrder order, Side side, StorageUpLo upLo,
 71      Transpose transA, Diag diag,
 72      IndexType m, IndexType n,
 73      const ALPHA &alpha,
 74      const MA *A, IndexType ldA,
 75      MB *B, IndexType ldB)
 76 {
 77     CXXBLAS_DEBUG_OUT("trmm_generic");
 78 
 79     trmm_generic(order, side, upLo, transA, diag, m, n, alpha, A, ldA, B, ldB);
 80 }
 81 
 82 #ifdef HAVE_CBLAS
 83 
 84 // strmm
 85 template <typename IndexType>
 86 typename If<IndexType>::isBlasCompatibleInteger
 87 trmm(StorageOrder order, Side side, StorageUpLo upLo,
 88      Transpose transA, Diag diag,
 89      IndexType m, IndexType n,
 90      float alpha,
 91      const float *A, IndexType ldA,
 92      float *B, IndexType ldB)
 93 {
 94     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_strmm");
 95 
 96     cblas_strmm(CBLAS::getCblasType(order),
 97                 CBLAS::getCblasType(side), CBLAS::getCblasType(upLo),
 98                 CBLAS::getCblasType(transA), CBLAS::getCblasType(diag),
 99                 m, n,
100                 alpha,
101                 A, ldA,
102                 B, ldB);
103 }
104 
105 // dtrmm
106 template <typename IndexType>
107 typename If<IndexType>::isBlasCompatibleInteger
108 trmm(StorageOrder order, Side side, StorageUpLo upLo,
109      Transpose transA, Diag diag,
110      IndexType m, IndexType n,
111      double alpha,
112      const double *A, IndexType ldA,
113      double *B, IndexType ldB)
114 {
115     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_dtrmm");
116 
117     cblas_dtrmm(CBLAS::getCblasType(order),
118                 CBLAS::getCblasType(side), CBLAS::getCblasType(upLo),
119                 CBLAS::getCblasType(transA), CBLAS::getCblasType(diag),
120                 m, n,
121                 alpha,
122                 A, ldA,
123                 B, ldB);
124 }
125 
126 // ctrmm
127 template <typename IndexType>
128 typename If<IndexType>::isBlasCompatibleInteger
129 trmm(StorageOrder order, Side side, StorageUpLo upLo,
130      Transpose transA, Diag diag,
131      IndexType m, IndexType n,
132      const ComplexFloat &alpha,
133      const ComplexFloat *A, IndexType ldA,
134      ComplexFloat *B, IndexType ldB)
135 {
136     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_ctrmm");
137 
138     cblas_ctrmm(CBLAS::getCblasType(order),
139                 CBLAS::getCblasType(side), CBLAS::getCblasType(upLo),
140                 CBLAS::getCblasType(transA), CBLAS::getCblasType(diag),
141                 m, n,
142                 reinterpret_cast<const float *>(&alpha),
143                 reinterpret_cast<const float *>(A), ldA,
144                 reinterpret_cast<const float *>(B), ldB);
145 }
146 
147 // ztrmm
148 template <typename IndexType>
149 typename If<IndexType>::isBlasCompatibleInteger
150 trmm(StorageOrder order, Side side, StorageUpLo upLo,
151      Transpose transA, Diag diag,
152      IndexType m, IndexType n,
153      const ComplexDouble &alpha,
154      const ComplexDouble *A, IndexType ldA,
155      ComplexDouble *B, IndexType ldB)
156 {
157     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_ztrmm");
158 
159     cblas_ztrmm(CBLAS::getCblasType(order),
160                 CBLAS::getCblasType(side), CBLAS::getCblasType(upLo),
161                 CBLAS::getCblasType(transA), CBLAS::getCblasType(diag),
162                 m, n,
163                 reinterpret_cast<const double *>(&alpha),
164                 reinterpret_cast<const double *>(A), ldA,
165                 reinterpret_cast<const double *>(B), ldB);
166 }
167 
168 #endif // HAVE_CBLAS
169 
} // namespace cxxblas
171 
172 #endif // CXXBLAS_LEVEL3_TRMM_TCC