1 /*
  2  *   Copyright (c) 2012, Michael Lehn
  3  *
  4  *   All rights reserved.
  5  *
  6  *   Redistribution and use in source and binary forms, with or without
  7  *   modification, are permitted provided that the following conditions
  8  *   are met:
  9  *
 10  *   1) Redistributions of source code must retain the above copyright
 11  *      notice, this list of conditions and the following disclaimer.
 12  *   2) Redistributions in binary form must reproduce the above copyright
 13  *      notice, this list of conditions and the following disclaimer in
 14  *      the documentation and/or other materials provided with the
 15  *      distribution.
 16  *   3) Neither the name of the FLENS development group nor the names of
 17  *      its contributors may be used to endorse or promote products derived
 18  *      from this software without specific prior written permission.
 19  *
 20  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 21  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 22  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 23  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 24  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 25  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 26  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 27  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 28  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 29  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 30  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 #ifndef FLENS_BLAS_CLOSURES_MM_TCC
 34 #define FLENS_BLAS_CLOSURES_MM_TCC 1
 35 
 36 #ifdef FLENS_DEBUG_CLOSURES
 37 #   include <flens/blas/blaslogon.h>
 38 #else
 39 #   include <flens/blas/blaslogoff.h>
 40 #endif
 41 
 42 namespace flens { namespace blas {
 43 
 44 //== GeneralMatrix - GeneralMatrix products ====================================
 45 
 46 //== TriangularMatrix - GeneralMatrix products =================================
 47 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
 48 typename RestrictTo<IsTriangularMatrix<MA>::value &&
 49                     IsGeneralMatrix<MB>::value &&
 50                     IsGeneralMatrix<MC>::value,
 51          void>::Type
 52 trmm(Side side, Transpose transA, Transpose transB, const ALPHA &alpha,
 53      const MA &_A, const MB &_B, const BETA &beta, MC &C)
 54 {
 55     using namespace DEBUGCLOSURE;
 56 
 57     typedef typename Result<MA>::Type  RMA;
 58     typedef typename Result<MB>::Type  RMB;
 59     typedef typename Result<MC>::Type  RMC;
 60 
 61 //
 62 //  In non-closure debug mode we do not allow temporaries for A or B.
 63 //
 64 #   ifndef FLENS_DEBUG_CLOSURES
 65     static_assert(IsSame<RMA, typename Result<RMA>::Type>::value,
 66                   "temporary required");
 67     static_assert(IsSame<RMB, typename Result<RMB>::Type>::value,
 68                   "temporary required");
 69 #   endif
 70 
 71 //
 72 //  If _A or _B is a closure temporaries get created
 73 //
 74     FLENS_BLASLOG_TMP_TRON;
 75     const RMA &A = _A;
 76     const RMB &B = _B;
 77     FLENS_BLASLOG_TMP_TROFF;
 78 
 79 //
 80 //  If beta!=0 or transB!=NoTrans or A and C share the same memopry we need
 81 //  temporaries
 82 //
 83 #   ifndef FLENS_DEBUG_CLOSURES
 84     ASSERT(beta==BETA(0));
 85     ASSERT(transB==NoTrans);
 86     ASSERT(!identical(A, C));
 87 #   else
 88     if (transB!=NoTrans) {
 89 //
 90 //      apply op(B) and recall trmm
 91 //
 92         typename RMB::NoView _B;
 93         FLENS_BLASLOG_TMP_ADD(_B);
 94         copy(transB, B, _B);
 95         trmm(side, transA, NoTrans, alpha, A, _B, beta, C);
 96         FLENS_BLASLOG_TMP_REMOVE(_B, B);
 97         if (!IsSame<RMA, typename Result<RMA>::Type>::value) {
 98             FLENS_BLASLOG_TMP_REMOVE(A, _A);
 99         }
100         if (!IsSame<RMB, typename Result<RMB>::Type>::value) {
101             FLENS_BLASLOG_TMP_REMOVE(B, _B);
102         }
103         return;
104     }
105     if (identical(A,C)) {
106         FLENS_BLASLOG_IDENTICAL(A, C);
107         typename RMC::NoView _C;
108         FLENS_BLASLOG_TMP_ADD(_C);
109 
110         trmm(side, transA, NoTrans, alpha, A, B, beta, _C);
111         C = _C;
112 
113         FLENS_BLASLOG_TMP_REMOVE(_C, C);
114         return;
115     }
116     typename RMC::NoView tmpC;
117     if (beta!=BETA(0)) {
118         FLENS_BLASLOG_TMP_ADD(tmpC);
119         tmpC = C;
120     }
121 #   endif
122 
123 //
124 //  trmm can only compute C = A*C or C = C*A.  So if B and C are not identical
125 //  we need to copy
126 //
127     if (!identical(C, B)) {
128         C = B;
129     }
130     mm(side, transA, alpha, A, C);
131 
132 #   ifdef FLENS_DEBUG_CLOSURES
133     if (beta!=BETA(0)) {
134         C += beta*tmpC;
135         FLENS_BLASLOG_TMP_REMOVE(tmpC, C);
136     }
137     if (!IsSame<RMA, typename Result<RMA>::Type>::value) {
138         FLENS_BLASLOG_TMP_REMOVE(A, _A);
139     }
140     if (!IsSame<RMB, typename Result<RMB>::Type>::value) {
141         FLENS_BLASLOG_TMP_REMOVE(B, _B);
142     }
143 #   endif
144 }
145 
146 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
147 void
148 mm(Transpose transA, Transpose transB, const ALPHA &alpha,
149    const TriangularMatrix<MA> &A, const GeneralMatrix<MB> &B,
150    const BETA &beta, Matrix<MC> &C)
151 {
152     trmm(Left, transA, transB, alpha, A.impl(), B.impl(), beta, C.impl());
153 }
154 
155 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
156 void
157 mm(Transpose transA, Transpose transB, const ALPHA &alpha,
158    const GeneralMatrix<MA> &A, const TriangularMatrix<MB> &B,
159    const BETA &beta, Matrix<MC> &C)
160 {
161     trmm(Right, transB, transA, alpha, B.impl(), A.impl(), beta, C.impl());
162 }
163 
164 //== SymmetricMatrix - GeneralMatrix products ==================================
165 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
166 void
167 symm(Side side, Transpose transB, const ALPHA &alpha,
168      const SymmetricMatrix<MA> &_A, const GeneralMatrix<MB> &_B,
169      const BETA &beta, Matrix<MC> &C)
170 {
171     using namespace DEBUGCLOSURE;
172 
173     typedef typename Result<typename MA::Impl>::Type  RMA;
174     typedef typename Result<typename MB::Impl>::Type  RMB;
175     typedef typename Result<typename MC::Impl>::Type  RMC;
176 
177 //
178 //  In non-closure debug mode we do not allow temporaries for A or B.
179 //
180 #   ifndef FLENS_DEBUG_CLOSURES
181     static_assert(IsSame<RMA, typename Result<RMA>::Type>::value,
182                   "temporary required");
183     static_assert(IsSame<RMB, typename Result<RMB>::Type>::value,
184                   "temporary required");
185     ASSERT(transB==NoTrans);
186 #   endif
187 
188 //
189 //  If _A or _B is a closure temporaries get created
190 //
191     FLENS_BLASLOG_TMP_TRON;
192     const RMA &A = _A.impl();
193     const RMB &B = _B.impl();
194     FLENS_BLASLOG_TMP_TROFF;
195 
196 //
197 //  call (sy)mm
198 //
199 #   ifndef FLENS_DEBUG_CLOSURES
200     mm(side, alpha, A.impl(), B.impl(), beta, C.impl());
201 #   else
202 //
203 //  if transB is not NoTrans we need another temporary
204 //
205     if (transB==NoTrans) {
206         mm(side, alpha, A, B, beta, C.impl());
207     } else {
208         typename RMB::NoView _B;
209         FLENS_BLASLOG_TMP_ADD(_B);
210         copy(transB, B, _B);
211         mm(side, alpha, A, _B, beta, C.impl());
212         FLENS_BLASLOG_TMP_REMOVE(_B, B);
213     }
214 #   endif
215 
216 #   ifdef FLENS_DEBUG_CLOSURES
217     if (!IsSame<RMA, typename Result<RMA>::Type>::value) {
218         FLENS_BLASLOG_TMP_REMOVE(A, _A);
219     }
220     if (!IsSame<RMB, typename Result<RMB>::Type>::value) {
221         FLENS_BLASLOG_TMP_REMOVE(B, _B);
222     }
223 #   endif
224 }
225 
226 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
227 void
228 mm(Transpose transA, Transpose transB, const ALPHA &alpha,
229    const SymmetricMatrix<MA> &A, const GeneralMatrix<MB> &B,
230    const BETA &beta, Matrix<MC> &C)
231 {
232     ASSERT(transA==NoTrans || transA==Trans);
233     symm(Left, transB, alpha, A.impl(), B.impl(), beta, C.impl());
234 }
235 
236 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
237 void
238 mm(Transpose transA, Transpose transB, const ALPHA &alpha,
239    const GeneralMatrix<MA> &A, const SymmetricMatrix<MB> &B,
240    const BETA &beta, Matrix<MC> &C)
241 {
242     ASSERT(transA==NoTrans || transA==Trans);
243     symm(Right, transB, alpha, B.impl(), A.impl(), beta, C.impl());
244 }
245 
246 //== HermitianMatrix - GeneralMatrix products ==================================
247 
248 //== Matrix - Matrix products ==================================================
249 //
250 //  This gets called if everything else fails
251 //
252 #ifdef FLENS_DEBUG_CLOSURES
253 
254 template <typename ALPHA, typename MA, typename MB, typename BETA, typename MC>
255 void
256 mm(Transpose transA, Transpose transB, const ALPHA &alpha,
257    const Matrix<MA> &A, const Matrix<MB> &B,
258    const BETA &beta, Matrix<MC> &C)
259 {
260     FLENS_BLASLOG_BEGIN_GEMM(transA, transB, alpha, A.impl(),
261                              B.impl(), beta, C.impl());
262 //
263 //  We create temporaries of type GeMatrix for all matrices on the right
264 //  hand side.  If A and B can be converted to GeMatrix types this at
265 //  least does the desired compuation.
266 //
267     typedef typename MA::Impl::ElementType        TA;
268     typedef typename MB::Impl::ElementType        TB;
269 
270     typedef GeMatrix<FullStorage<TA, ColMajor> >  RMA;
271     typedef GeMatrix<FullStorage<TB, ColMajor> >  RMB;
272 
273     FLENS_BLASLOG_TMP_TRON;
274     const RMA &_A = A.impl();
275     const RMB &_B = B.impl();
276     FLENS_BLASLOG_TMP_TROFF;
277 
278     mm(transA, transB, alpha, _A, _B, beta, C.impl());
279 
280     FLENS_BLASLOG_TMP_REMOVE(_A, A);
281     FLENS_BLASLOG_TMP_REMOVE(_B, B);
282 
283     FLENS_BLASLOG_END;
284 }
285 
286 #endif
287 
288 } } // namespace blas, flens
289 
290 #endif // FLENS_BLAS_CLOSURES_MM_TCC
291