1 /*
  2  *   Copyright (c) 2009, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 #ifndef FLENS_BLAS_LEVEL3_MM_TCC
 34 #define FLENS_BLAS_LEVEL3_MM_TCC
 35 
 36 #include <flens/typedefs.h>
 37 
 38 namespace flens { namespace blas {
 39 
 40 //== product type: GeneralMatrix - GeneralMatrix products
 41 
 42 //-- gemm
 43 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
 44 void
 45 mm(Transpose transA, Transpose transB,
 46    const ALPHA &alpha,
 47    const GeMatrix<MA> &A, const GeMatrix<MB> &B,
 48    const BETA &beta,
 49    GeMatrix<MC> &C)
 50 {
 51 #   ifndef NDEBUG
 52     int kA = (transA==NoTrans) ? A.numCols() : A.numRows();
 53     int kB = (transB==NoTrans) ? B.numRows() : B.numCols();
 54     ASSERT(kA==kB);
 55 #   endif
 56 
 57     typedef typename GeMatrix<MC>::IndexType IndexType;
 58     IndexType m = (transA==NoTrans) ? A.numRows() : A.numCols();
 59     IndexType n = (transB==NoTrans) ? B.numCols() : B.numRows();
 60     IndexType k = (transA==NoTrans) ? A.numCols() : A.numRows();
 61 
 62     if (MC::order!=MA::order) {
 63         transA = Transpose(transA ^ Trans);
 64     }
 65     if (MC::order!=MB::order) {
 66         transB = Transpose(transB ^ Trans);
 67     }
 68 
 69 #   ifndef NDEBUG
 70     if (beta!=BETA(0)) {
 71         if (C.numRows()!=0 && C.numCols()!=0) {
 72             ASSERT(C.numRows()==m && C.numCols()==n);
 73         }
 74     }
 75 #   endif
 76 
 77     if ((C.numRows()!=m) || (C.numCols()!=n)) {
 78         C.resize(m, n);
 79     }
 80 
 81 #   ifndef FLENS_DEBUG_CLOSURES
 82     ASSERT(!DEBUGCLOSURE::identical(A, C));
 83     ASSERT(!DEBUGCLOSURE::identical(B, C));
 84 #   else
 85 //
 86 //  If A or B is identical with C we copy C into a temporary first.  Then
 87 //  we compute the matrix-matrix product and afterwards copy the result into C.
 88 //
 89     if (DEBUGCLOSURE::identical(A, C) || DEBUGCLOSURE::identical(B, C)) {
 90         typename GeMatrix<MC>::NoView _C;
 91         FLENS_BLASLOG_TMP_ADD(_C);
 92 
 93         if (beta!=BETA(0)) {
 94             _C = C;
 95         }
 96         mm(transA, transB, alpha, A, B, beta, _C);
 97         C = _C;
 98 
 99         FLENS_BLASLOG_TMP_REMOVE(_C, C);
100         return;
101     }
102 #   endif
103 
104     FLENS_BLASLOG_SETTAG("--> ");
105     FLENS_BLASLOG_BEGIN_GEMM(transA, transB, alpha, A, B, beta, C);
106 
107 #   ifdef HAVE_CXXBLAS_GEMM
108     cxxblas::gemm(MC::order,
109                   transA, transB,
110                   C.numRows(),
111                   C.numCols(),
112                   k,
113                   alpha,
114                   A.data(), A.leadingDimension(),
115                   B.data(), B.leadingDimension(),
116                   beta,
117                   C.data(), C.leadingDimension());
118 #   else
119     ASSERT(0);
120 #   endif
121 
122     FLENS_BLASLOG_END;
123     FLENS_BLASLOG_UNSETTAG;
124 }
125 
126 //== product type: TriangularMatrix - GeneralMatrix products
127 
128 //-- trmm
129 template <typename ALPHA, typename MA, typename MB>
130 void
131 mm(Side side,
132    Transpose transA, const ALPHA &alpha, const TrMatrix<MA> &A,
133    GeMatrix<MB> &B)
134 {
135 #   ifndef NDEBUG
136     ASSERT(MB::order==MA::order);
137     if (side==Left) {
138         assert(A.dim()==B.numRows());
139     } else {
140         assert(B.numCols()==A.dim());
141     }
142 #   endif
143 
144     FLENS_BLASLOG_SETTAG("--> ");
145     FLENS_BLASLOG_BEGIN_TRMM(side, transA, alpha, A, B);
146 
147 #   ifdef HAVE_CXXBLAS_TRMM
148     cxxblas::trmm(MB::order, side,
149                   A.upLo(), transA, A.diag(),
150                   B.numRows(), B.numCols(),
151                   alpha,
152                   A.data(), A.leadingDimension(),
153                   B.data(), B.leadingDimension());
154 #   else
155     ASSERT(0);
156 #   endif
157 
158     FLENS_BLASLOG_END;
159     FLENS_BLASLOG_UNSETTAG;
160 }
161 
162 
163 //== product type: SymmetricMatrix - GeneralMatrix products
164 
165 //-- symm
166 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
167 void
168 mm(Side side,
169    const ALPHA &alpha, const SyMatrix<MA> &A, const GeMatrix<MB> &B,
170    const BETA &beta, GeMatrix<MC> &C)
171 {
172 #   ifndef NDEBUG
173     ASSERT(MC::order==MB::order);
174     if (side==Left) {
175         ASSERT(A.dim()==B.numRows());
176     } else {
177         ASSERT(B.numCols()==A.dim());
178     }
179 #   endif
180 
181     StorageUpLo upLo = (MC::order==MA::order)
182                      ? A.upLo()
183                      : StorageUpLo(! A.upLo());
184 
185     typedef typename GeMatrix<MC>::IndexType IndexType;
186     IndexType m = (side==Left) ? A.dim() : B.numRows();
187     IndexType n = (side==Left) ? B.numCols() : A.dim();
188 
189 #   ifndef NDEBUG
190     if (beta!=BETA(0)) {
191         if (C.numRows()!=0 && C.numCols()!=0) {
192             ASSERT(C.numRows()==m && C.numCols()==n);
193         }
194     }
195 #   endif
196 
197     if ((C.numRows()!=m) || (C.numCols()!=n)) {
198         C.resize(m, n);
199     }
200 
201 #   ifndef FLENS_DEBUG_CLOSURES
202     ASSERT(!DEBUGCLOSURE::identical(A, C));
203     ASSERT(!DEBUGCLOSURE::identical(B, C));
204 #   else
205 //
206 //  If A or B is identical with C we copy C into a temporary first.  Then
207 //  we compute the matrix-matrix product and afterwards copy the result into C.
208 //
209     if (DEBUGCLOSURE::identical(A, C) || DEBUGCLOSURE::identical(B, C)) {
210         typename GeMatrix<MC>::NoView _C;
211         FLENS_BLASLOG_TMP_ADD(_C);
212 
213         if (beta!=BETA(0)) {
214             _C = C;
215         }
216         mm(side, alpha, A, B, beta, _C);
217         C = _C;
218 
219         FLENS_BLASLOG_TMP_REMOVE(_C, C);
220         return;
221     }
222 #   endif
223 
224     FLENS_BLASLOG_SETTAG("--> ");
225     FLENS_BLASLOG_BEGIN_SYMM(side, alpha, A, B, beta, C);
226 
227 #   ifdef HAVE_CXXBLAS_SYMM
228     cxxblas::symm(MC::order, side,
229                   upLo,
230                   C.numRows(), C.numCols(),
231                   alpha,
232                   A.data(), A.leadingDimension(),
233                   B.data(), B.leadingDimension(),
234                   beta,
235                   C.data(), C.leadingDimension());
236 #   else
237     ASSERT(0);
238 #   endif
239 
240     FLENS_BLASLOG_END;
241     FLENS_BLASLOG_UNSETTAG;
242 }
243 
244 //== product type: HermitianMatrix - GeneralMatrix products
245 
246 //-- hemm
247 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
248 void
249 mm(Side side,
250    const ALPHA &alpha, const HeMatrix<MA> &A, const GeMatrix<MB> &B,
251    const BETA &beta, GeMatrix<MC> &C)
252 {
253 #   ifndef NDEBUG
254     ASSERT(MC::order==MB::Oorder);
255     if (side==Left) {
256         ASSERT(A.dim()==B.numRows());
257     } else {
258         ASSERT(B.numCols()==A.dim());
259     }
260 #   endif
261 
262     StorageUpLo upLo = (MC::order==MA::order)
263                      ? A.upLo()
264                      : StorageUpLo(! A.upLo());
265 
266     typedef typename GeMatrix<MC>::IndexType IndexType;
267     IndexType m = (side==Left) ? A.dim() : B.numRows();
268     IndexType n = (side==Left) ? B.numCols() : A.dim();
269  
270     ASSERT((beta==static_cast<BETA>(0)) || (C.numRows()==m));
271     ASSERT((beta==static_cast<BETA>(0)) || (C.numCols()==n));
272  
273     if ((C.numRows()!=m) || (C.numCols()!=n)) {
274         C.resize(m,n);
275     }
276 
277 #   ifdef HAVE_CXXBLAS_HEMM
278     cxxblas::hemm(MC::order, side,
279                   upLo,
280                   C.numRows(), C.numCols(),
281                   alpha,
282                   A.data(), A.leadingDimension(),
283                   B.data(), B.leadingDimension(),
284                   beta,
285                   C.data(), C.leadingDimension());
286 #   else
287     ASSERT(0);
288 #   endif
289 }
290 
291 //== Forwarding ================================================================
292 
293 //-- GeneralMatrix - GeneralMatrix products
294 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
295 typename RestrictTo<IsGeneralMatrix<MA>::value &&
296                     IsGeneralMatrix<MB>::value &&
297                    !IsClosure<MA>::value &&
298                    !IsClosure<MB>::value &&
299                     IsSame<MC, typename MC::Impl>::value,
300          void>::Type
301 mm(Transpose transA, Transpose transB, const ALPHA &alpha,
302    const MA &A, const MB &B, const BETA &beta, MC &&C)
303 {
304     CHECKPOINT_ENTER;
305     mm(transA, transB, alpha, A, B, beta, C);
306     CHECKPOINT_LEAVE;
307 }
308 
309 //-- TriangularMatrix - GeneralMatrix products
310 template <typename ALPHA, typename MA, typename MB>
311 typename RestrictTo<IsTriangularMatrix<MA>::value &&
312                     IsGeneralMatrix<MB>::value &&
313                    !IsClosure<MA>::value &&
314                     IsSame<MB, typename MB::Impl>::value,
315          void>::Type
316 mm(Side side, Transpose transA, const ALPHA &alpha, const MA &A, MB &&B)
317 {
318     CHECKPOINT_ENTER;
319     mm(side, transA, alpha, A, B);
320     CHECKPOINT_LEAVE;
321 }
322 
323 
324 //-- SymmetricMatrix - GeneralMatrix products
325 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
326 typename RestrictTo<IsSymmetricMatrix<MA>::value &&
327                     IsGeneralMatrix<MB>::value &&
328                    !IsClosure<MA>::value &&
329                    !IsClosure<MB>::value &&
330                     IsSame<MC, typename MC::Impl>::value,
331          void>::Type
332 mm(Side side, const ALPHA &alpha, const MA &A, const MB &B,
333    const BETA &beta, MC &&C)
334 {
335     CHECKPOINT_ENTER;
336     mm(side, alpha, A, B, beta, C);
337     CHECKPOINT_LEAVE;
338 }
339 
340 //-- HermitianMatrix - GeneralMatrix products
341 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
342 typename RestrictTo<IsHermitianMatrix<MA>::value &&
343                     IsGeneralMatrix<MB>::value &&
344                    !IsClosure<MA>::value &&
345                    !IsClosure<MB>::value &&
346                     IsSame<MC, typename MC::Impl>::value,
347          void>::Type
348 mm(Side side, const ALPHA &alpha, const MA &A, const MB &B,
349    const BETA &beta, MC &&C)
350 {
351     CHECKPOINT_ENTER;
352     mm(side, alpha, A, B, beta, C);
353     CHECKPOINT_LEAVE;
354 }
355 
356 
357 } } // namespace blas, flens
358 
359 #endif // FLENS_BLAS_LEVEL3_MM_TCC