Template-ifizierung: gemm_blocked.h

Bei der Verallgemeinerung der geblockten GEMM-Implementierung ist folgendes zu beachten:

Wir verwenden folgende Techniken, um dies zu beheben:

Erster Kontakt mit Traits

Kurz gesagt ermöglichen es die sogenannten Traits, dass zur Compilezeit bekannte Typen und Konstanten auf andere Typen und Konstanten abgebildet werden.

Typen auf Typen abbilden

Betrachten wir folgendes Beispiel, das das Maximum von zwei Integer-Zahlen zurückgibt:

int
max(int a, int b)
{
    return (a>b) ? a : b;
}

Für Zahlen vom Typ double würde man dies analog programmieren:

double
max(double a, double b)
{
    return (a>b) ? a : b;
}

Und schnell hat man die Idee, dass sich dies mit Templates wie folgt verallgemeinern lässt:

template <typename T>
T
max(T a, T b)
{
    return (a>b) ? a : b;
}

Allerdings setzt dies voraus, dass beide Argumente den gleichen Typ besitzen. Das heisst der Funktionsaufruf

double m = max(1, 2.0);

erzeugt einen Compiler-Fehler, denn 1 hat den Typ int und 2.0 den Typ double. Dies hat aber nichts mit den Templates zu tun. Auch dann, wenn nur überladene Funktionen benutzt werden müsste man dazu eine Variante

double
max(int a, double b)
{
    return (a>b) ? a : b;
}

bereitstellen. Und für einen Aufruf max(2.0, 1) eine weitere. Deshalb wäre es natürlich interessant, eine Template-Funktion max zu implementieren, bei der die Argumente unterschiedliche Typen besitzen können. Die Frage ist aber, wie man hier den Rückgabe-Typ festlegen soll:

template <typename T, typename S>
/* ????? */
max(T a, S b)
{
    return (a>b) ? a : b;
}

Sind die Paramter T und S vom Typ int oder double, dann sollte der Rückgabe-Wert in Abhängigkeit von T und S wie folgt festgelegt werden:

T

S

Rückgabe-Typ

int

int

int

double

int

double

int

double

double

double

double

double

Dies ist offensichtlich eine Verknüpfungstabelle, die T und S auf den Rückgabe-Typen abbilden. Diese soll durch eine Template-Klasse Decl realisiert werden, die wir dann wie folgt verwenden können:

template <typename T, typename S>
typename Decl<T,S>::Type
max(T a, S b)
{
    return (a>b) ? a : b;
}

Erster Wurf

In einem ersten Wurf gehen wir völlig unelegant vor und codieren die vollständige Tabelle mit C++-Mitteln. Der Trick besteht darin, eine Template-Klasse für alle aufgelisteten Fälle zu spezialieren und dabei den Rückgabe-Typen durch einen typedef festzulegen:

template <typename T, typename S>
struct Decl
{
};

template <>
struct Decl<int, int>
{
    typedef int Type;
};

template <>
struct Decl<double, int>
{
    typedef double Type;
};

template <>
struct Decl<int, double>
{
    typedef double Type;
};

template <>
struct Decl<double, double>
{
    typedef double Type;
};

Um zu testen, ob diese Abbildung tatsächlich zur Compile-Zeit durchgeführt wird, kann folgendes Test-Programm benutzt werden:

#include <cstdio>
#include <decl.h>

template <typename T>
void
checkType(T)
{
    printf("Unknown type\n");
}

void
checkType(int)
{
    printf("int\n");
}

void
checkType(double)
{
    printf("double\n");
}

template <typename T, typename S>
typename Decl<T,S>::Type
max(T a, S b)
{
    return (a>b) ? a : b;
}

int
main()
{
    checkType(max(1,   1.0));
    checkType(max(1.0, 1  ));
    checkType(max(1,   1  ));
    checkType(max(1.0, 1.0));
}
$shell> g++ -Wall -I version1 -o test_decl test_decl.cc
$shell> ./test_decl
double
double
int
double

Zweiter Wurf

Zunächst stört uns, dass wir offensichtliche Fälle explizit codieren müssen und das ganze Konstrukt nur für int und double funktioniert. Sind beide Typen gleich, dann ist auch genau dies der Rückgabe-Typ. Dies können wir durch eine teilweise Spezialisierung umsetzen:

template <typename T, typename S>
struct Decl
{
};

// case where both types are the same
template <typename T>
struct Decl<T,T>
{
    typedef T Type;
};


template <>
struct Decl<double, int>
{
    typedef double Type;
};

template <>
struct Decl<int, double>
{
    typedef double Type;
};
$shell> g++ -Wall -I version2 -o test_decl test_decl.cc
$shell> ./test_decl
double
double
int
double

Dritter Wurf

Natürlich ist uns die Symmetrie der Tabelle nicht entgangen: Decl<T,S> und Decl<S,T> sollten immer den gleichen Typ definieren. Auch das können wir realisieren:

template <typename T, typename S>
struct Decl
{
    typedef typename Decl<S,T>::Type Type;
};

// case where both types are the same
template <typename T>
struct Decl<T,T>
{
    typedef T Type;
};

template <>
struct Decl<int, double>
{
    typedef double Type;
};
$shell> g++ -Wall -I version3 -o test_decl test_decl.cc
$shell> ./test_decl
double
double
int
double

Dadurch wurde der Programmieraufwand deutlich reduziert. Dennoch muss man diese Tabelle für alle anderen in Frage kommenenden Typen (long, float, ...) und Kombinationen mit viel Fleiß ergänzen. Aber seit C++11 wurde dies von anderen erledigt!

Ab C++11

In <type_traits> sind eine Vielzahl an nützlichen Trait-Klassen definiert (Link). Eine davon ist std::common_type<T,S>::type. Dies liefert den Typ, in den sowohl Instanzen von T als auch S konvertiert werden können:

#include <cstdio>
#include <type_traits>

template <typename T>
void
checkType(T)
{
    printf("Unknown type\n");
}

void
checkType(int)
{
    printf("int\n");
}

void
checkType(double)
{
    printf("double\n");
}

template <typename T, typename S>
typename std::common_type<T,S>::type
max(T a, S b)
{
    return (a>b) ? a : b;
}

int
main()
{
    checkType(max(1,   1.0));
    checkType(max(1.0, 1  ));
    checkType(max(1,   1  ));
    checkType(max(1.0, 1.0));
}
$shell> g++ -std=c++11 -Wall -I version1 -o test_decl2 test_decl2.cc
$shell> ./test_decl2
double
double
int
double

Typen auf Konstanten abbilden

Hier haben wir ein konkretes Beispiel. Der Element-Typ einer Matrix soll auf die Block-Dimensionen MC, NC, KC, MR und NR abgebildet werden. Beispielsweise:

Datentyp

MC

KC

NC

MR

NR

float

256

512

4096

8

8

double

256

256

4096

4

8

std::complex<float>

256

256

4096

4

8

std::complex<double>

256

128

4096

4

4

Sonst

64

64

256

2

2

Mit Traits werden wir folgendes ermöglichen:

template <typename T>
void
foo(/* ...*/)
{
    int MC = BlockSize<T>::MC;

    /* ... */
}

Dazu definieren wir eine Klasse BlockSize mit einem Template-Parameter für den Element-Typ. Dieser enthält in der allgemeinen Form die Default-Werte als statische Konstanten:

template <typename T>
struct BlockSize
{
    static const int MC = 64;
    static const int KC = 64;
    static const int NC = 256;
    static const int MR = 2;
    static const int NR = 2;
};

Durch Spezialisierungen kann obige Tabelle codiert werden:

template <>
struct BlockSize<float>
{
    static const int MC = 256;
    static const int KC = 512;
    static const int NC = 4096;
    static const int MR = 8;
    static const int NR = 8;
};

template <>
struct BlockSize<double>
{
    static const int MC = 256;
    static const int KC = 256;
    static const int NC = 4096;
    static const int MR = 4;
    static const int NR = 8;
};

/* ... */

Damit haben wir vorerst alle Werkzeuge, um die geblockte GEMM-Implementierung zu verallgemeinern.

Vorlage für gemm_blocked.h

#ifndef HPC_GEMM_BLOCKED_H
#define HPC_GEMM_BLOCKED_H 1

#include <complex>
#include <type_traits>
#include "ulmblas.h"

namespace blocked {

template <typename T>
struct BlockSize
{
    /* ... */
};

template <typename T, typename Index>
void
pack_A(Index mc, Index kc,
       const T *A, Index incRowA, Index incColA,
       T *p)
{
    /* ... */
}

template <typename T, typename Index>
void
pack_B(Index kc, Index nc,
       const T *B, Index incRowB, Index incColB,
       T *p)
{
    /* ... */
}

template <typename T, typename Index>
void
ugemm(Index kc, T alpha,
      const T *A, const T *B,
      T beta,
      T *C, Index incRowC, Index incColC)
{
    /* ... */
}

template <typename T, typename Index>
void
mgemm(Index mc, Index nc, Index kc,
      T alpha,
      const T *A, const T *B,
      T beta,
      T *C, Index incRowC, Index incColC)
{
    /* ... */
}

template <typename T, typename Index>
void
gemm(Index m, Index n, Index k,
     T alpha,
     const T *A, Index incRowA, Index incColA,
     const T *B, Index incRowB, Index incColB,
     T beta,
     T *C, Index incRowC, Index incColC)
{
    /* ... */
}

} // namespace blocked

#endif // HPC_GEMM_BLOCKED_H