1 /*
  2  *   Copyright (c) 2009, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 #ifndef CXXBLAS_LEVEL3_GEMM_TCC
 34 #define CXXBLAS_LEVEL3_GEMM_TCC 1
 35 
 36 namespace cxxblas {
 37 
 38 template <typename IndexType, typename ALPHA, typename MA, typename MB,
 39           typename BETA, typename MC>
 40 void
 41 gemm_generic(StorageOrder order,
 42              Transpose transA, Transpose transB,
 43              IndexType m, IndexType n, IndexType k,
 44              const ALPHA &alpha,
 45              const MA *A, IndexType ldA,
 46              const MB *B, IndexType ldB,
 47              const BETA &beta,
 48              MC *C, IndexType ldC)
 49 {
 50     CXXBLAS_DEBUG_OUT("gemm_generic");
 51 
 52     if ((m==0) || (n==0)) {
 53         return;
 54     }
 55     if (order==ColMajor) {
 56         gemm_generic(RowMajor, transB, transA,
 57                      n, m, k, alpha,
 58                      B, ldB, A, ldA,
 59                      beta,
 60                      C, ldC);
 61         return;
 62     }
 63 
 64     gescal(order, m, n, beta, C, ldC);
 65     if (alpha==ALPHA(0)) {
 66         return;
 67     }
 68     if ((transA==NoTrans) && (transB==NoTrans)) {
 69         for (IndexType l=0; l<n; ++l) {
 70             gemv(order, NoTrans, m, k, alpha, A, ldA, B+l, ldB,
 71                  BETA(1), C+l, ldC);
 72         }
 73     }
 74     if ((transA==NoTrans) && (transB==Conj)) {
 75         for (IndexType l=0; l<n; ++l) {
 76             gemv(order, NoTrans, Conj, m, k, alpha, A, ldA, B+l, ldB,
 77                  BETA(1), C+l, ldC);
 78         }
 79     }
 80     if ((transA==NoTrans) && (transB==Trans)) {
 81         for (IndexType l=0; l<n; ++l) {
 82             gemv(order, NoTrans, m, k, alpha, A, ldA, B+l*ldB, IndexType(1),
 83                  BETA(1), C+l, ldC);
 84         }
 85     }
 86     if ((transA==NoTrans) && (transB==ConjTrans)) {
 87         for (IndexType l=0; l<n; ++l) {
 88             gemv(order, NoTrans, Conj, m, k,
 89                  alpha, A, ldA, B+l*ldB, IndexType(1),
 90                  BETA(1), C+l, ldC);
 91         }
 92     }
 93 
 94     if ((transA==Conj) && (transB==NoTrans)) {
 95         for (IndexType l=0; l<n; ++l) {
 96             gemv(order, NoTrans, m, k, alpha, A, ldA, B+l, ldB,
 97                  BETA(1), C+l, ldC);
 98         }
 99     }
100     if ((transA==Conj) && (transB==Conj)) {
101         for (IndexType l=0; l<n; ++l) {
102             gemv(order, Conj, Conj, m, k, alpha, A, ldA, B+l, ldB,
103                  BETA(1), C+l, ldC);
104         }
105     }
106     if ((transA==Conj) && (transB==Trans)) {
107         for (IndexType l=0; l<n; ++l) {
108             gemv(order, Conj, m, k, alpha, A, ldA, B+l*ldB, IndexType(1),
109                  BETA(1), C+l, ldC);
110         }
111     }
112     if ((transA==Conj) && (transB==ConjTrans)) {
113         for (IndexType l=0; l<n; ++l) {
114             gemv(order, Conj, Conj, m, k, alpha, A, ldA, B+l*ldB, IndexType(1),
115                  BETA(1), C+l, ldC);
116         }
117     }
118 
119     if ((transA==Trans) && (transB==NoTrans)) {
120         for (IndexType l=0; l<n; ++l) {
121             gemv(order, Trans, k, m, alpha, A, ldA, B+l, ldB,
122                  BETA(1), C+l, ldC);
123         }
124     }
125     if ((transA==Trans) && (transB==Conj)) {
126         for (IndexType l=0; l<n; ++l) {
127             gemv(order, Trans, Conj, k, m, alpha, A, ldA, B+l, ldB,
128                  BETA(1), C+l, ldC);
129         }
130     }
131     if ((transA==Trans) && (transB==Trans)) {
132         for (IndexType l=0; l<n; ++l) {
133             gemv(order, Trans, k, m, alpha, A, ldA, B+l*ldB, IndexType(1),
134                  BETA(1), C+l, ldC);
135         }
136     }
137     if ((transA==Trans) && (transB==ConjTrans)) {
138         for (IndexType l=0; l<n; ++l) {
139             gemv(order, Trans, Conj, k, m,
140                  alpha, A, ldA, B+l*ldB, IndexType(1),
141                  BETA(1), C+l, ldC);
142         }
143     }
144 
145     if ((transA==ConjTrans) && (transB==NoTrans)) {
146         for (IndexType l=0; l<n; ++l) {
147             gemv(order, ConjTrans, k, m, alpha, A, ldA, B+l, ldB,
148                  BETA(1), C+l, ldC);
149         }
150     }
151     if ((transA==ConjTrans) && (transB==Conj)) {
152         for (IndexType l=0; l<n; ++l) {
153             gemv(order, ConjTrans, k, m, alpha, A, ldA, B+l, ldB,
154                  BETA(1), C+l, ldC);
155         }
156     }
157     if ((transA==ConjTrans) && (transB==Trans)) {
158         for (IndexType l=0; l<n; ++l) {
159             gemv(order, ConjTrans, k, m, alpha, A, ldA, B+l*ldB, IndexType(1),
160                  BETA(1), C+l, ldC);
161         }
162     }
163     if ((transA==ConjTrans) && (transB==ConjTrans)) {
164         for (IndexType l=0; l<n; ++l) {
165             gemv(order, ConjTrans, Conj, k, m,
166                  alpha, A, ldA, B+l*ldB, IndexType(1),
167                  BETA(1), C+l, ldC);
168         }
169     }
170 }
171 
172 template <typename IndexType, typename ALPHA, typename MA, typename MB,
173           typename BETA, typename MC>
174 void
175 gemm(StorageOrder order,
176      Transpose transA, Transpose transB,
177      IndexType m, IndexType n, IndexType k,
178      const ALPHA &alpha,
179      const MA *A, IndexType ldA,
180      const MB *B, IndexType ldB,
181      const BETA &beta,
182      MC *C, IndexType ldC)
183 {
184     gemm_generic(order, transA, transB, m, n, k,
185                  alpha, A, ldA, B, ldB,
186                  beta,
187                  C, ldC);
188 }
189 
190 #ifdef HAVE_CBLAS
191 
192 // sgemm
193 template <typename IndexType>
194 typename If<IndexType>::isBlasCompatibleInteger
195 gemm(StorageOrder order,
196      Transpose transA, Transpose transB,
197      IndexType m, IndexType n, IndexType k,
198      float alpha,
199      const float *A, IndexType ldA,
200      const float *B, IndexType ldB,
201      float beta,
202      float *C, IndexType ldC)
203 {
204     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_sgemm");
205 
206     cblas_sgemm(CBLAS::getCblasType(order),
207                 CBLAS::getCblasType(transA), CBLAS::getCblasType(transB),
208                 m, n, k,
209                 alpha,
210                 A, ldA,
211                 B, ldB,
212                 beta,
213                 C, ldC);
214 }
215 
216 // dgemm
217 template <typename IndexType>
218 typename If<IndexType>::isBlasCompatibleInteger
219 gemm(StorageOrder order,
220      Transpose transA, Transpose transB,
221      IndexType m, IndexType n, IndexType k,
222      double alpha,
223      const double *A, IndexType ldA,
224      const double *B, IndexType ldB,
225      double beta,
226      double *C, IndexType ldC)
227 {
228     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_dgemm");
229 
230     cblas_dgemm(CBLAS::getCblasType(order),
231                 CBLAS::getCblasType(transA), CBLAS::getCblasType(transB),
232                 m, n, k,
233                 alpha,
234                 A, ldA,
235                 B, ldB,
236                 beta,
237                 C, ldC);
238 }
239 
240 // cgemm
241 template <typename IndexType>
242 typename If<IndexType>::isBlasCompatibleInteger
243 gemm(StorageOrder order,
244      Transpose transA, Transpose transB,
245      IndexType m, IndexType n, IndexType k,
246      const ComplexFloat &alpha,
247      const ComplexFloat *A, IndexType ldA,
248      const ComplexFloat *B, IndexType ldB,
249      const ComplexFloat &beta,
250      ComplexFloat *C, IndexType ldC)
251 {
252     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_cgemm");
253 
254     cblas_cgemm(CBLAS::getCblasType(order),
255                 CBLAS::getCblasType(transA), CBLAS::getCblasType(transB),
256                 m, n, k,
257                 reinterpret_cast<const float *>(&alpha),
258                 reinterpret_cast<const float *>(A), ldA,
259                 reinterpret_cast<const float *>(B), ldB,
260                 reinterpret_cast<const float *>(&beta),
261                 reinterpret_cast<const float *>(C), ldC);
262 }
263 
264 // zgemm
265 template <typename IndexType>
266 typename If<IndexType>::isBlasCompatibleInteger
267 gemm(StorageOrder order,
268      Transpose transA, Transpose transB,
269      IndexType m, IndexType n, IndexType k,
270      const ComplexDouble &alpha,
271      const ComplexDouble *A, IndexType ldA,
272      const ComplexDouble *B, IndexType ldB,
273      const ComplexDouble &beta,
274      ComplexDouble *C, IndexType ldC)
275 {
276     CXXBLAS_DEBUG_OUT("[" BLAS_IMPL "] cblas_zgemm");
277 
278     cblas_zgemm(CBLAS::getCblasType(order),
279                 CBLAS::getCblasType(transA), CBLAS::getCblasType(transB),
280                 m, n, k,
281                 reinterpret_cast<const double *>(&alpha),
282                 reinterpret_cast<const double *>(A), ldA,
283                 reinterpret_cast<const double *>(B), ldB,
284                 reinterpret_cast<const double *>(&beta),
285                 reinterpret_cast<const double *>(C), ldC);
286 }
287 
288 #endif // HAVE_CBLAS
289 
290 // namespace cxxblas
291 
292 #endif // CXXBLAS_LEVEL3_GEMM_TCC