GEMM Micro Kernel in Assembly Code

Exercise

Use the GEMM micro kernel given below, which is written entirely in assembly code. It computes C <- alpha*A*B + beta*C for a packed 4-by-k panel A and a packed k-by-8 panel B. You need to add a forward declaration for it in ulmblas.c, as sketched below.
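
A minimal sketch of the declaration for ulmblas.c (the signature is taken from the comment block at the top of the assembly file; including <stdint.h> for uint64_t and int64_t is an assumption about what ulmblas.c does not already provide):

    #include <stdint.h>

    void
    dgemm_micro_avx_4x8(uint64_t k,
                        double   alpha,
                        double   *A,
                        double   *B,
                        double   beta,
                        double   *C,
                        int64_t  incRowC,
                        int64_t  incColC);

The micro kernel itself: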

#  void
#  dgemm_micro_avx_4x8(uint64_t k,
#                      double   alpha,
#                      double   *A,
#                      double   *B,
#                      double   beta,
#                      double   *C,
#                      int64_t  incRowC,
#                      int64_t  incColC);
#
#  Arguments are passed in registers as follows:
#
#           %rdi        :       uint64_t       k
#           %xmm0       :       double         alpha
#           %rsi        :       double         *A
#           %rdx        :       double         *B
#           %xmm1       :       double         beta
#           %rcx        :       double         *C
#           %r8         :       int64_t        incRowC
#           %r9         :       int64_t        incColC
#
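#  This is simply the System V AMD64 calling convention: the first six
#  integer/pointer arguments are passed in %rdi, %rsi, %rdx, %rcx, %r8 and
#  %r9, and floating-point arguments in %xmm0, %xmm1, ...
#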

.globl dgemm_micro_avx_4x8
dgemm_micro_avx_4x8:
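
        # alpha, beta, %rbx, and %rbp below are stashed in the red zone (the
        # 128 bytes below %rsp that a leaf function may use without adjusting
        # the stack pointer, per the System V AMD64 ABI).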

        vmovlpd         %xmm0,       -8(%rsp)           # save alpha
        vmovlpd         %xmm1,      -16(%rsp)           # save beta
        movq            %rbx,       -24(%rsp)           # save %rbx (callee-saved)
        movq            %rbp,       -32(%rsp)           # save %rbp (callee-saved)

                                                                # Initialize C_
        vxorpd          %ymm8,          %ymm8,          %ymm8   #   C_[0,0:3]
        vxorpd          %ymm9,          %ymm9,          %ymm9   #   C_[0,4:7]
        vxorpd          %ymm10,         %ymm10,         %ymm10  #   C_[1,0:3]
        vxorpd          %ymm11,         %ymm11,         %ymm11  #   C_[1,4:7]
        vxorpd          %ymm12,         %ymm12,         %ymm12  #   C_[2,0:3]
        vxorpd          %ymm13,         %ymm13,         %ymm13  #   C_[2,4:7]
        vxorpd          %ymm14,         %ymm14,         %ymm14  #   C_[3,0:3]
        vxorpd          %ymm15,         %ymm15,         %ymm15  #   C_[3,4:7]

        jmp     .loop_check

.loop:
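        # Each iteration adds the rank-1 update C_ += A[0:3,l] * B[l,0:7]:
        # every element of the current A column is broadcast across a ymm
        # register and multiplied with both halves of the current B row.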
        vbroadcastsd              0 * 8(%rsi),          %ymm0   # A[0]
        vbroadcastsd              1 * 8(%rsi),          %ymm1   # A[1]
        vbroadcastsd              2 * 8(%rsi),          %ymm2   # A[2]
        vbroadcastsd              3 * 8(%rsi),          %ymm3   # A[3]
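        # B must be 32-byte aligned: vmovapd faults on unaligned addresses.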
        vmovapd                  0 * 32(%rdx),          %ymm4   # B[0:3]
        vmovapd                  1 * 32(%rdx),          %ymm5   # B[4:7]

        vmulpd          %ymm0,          %ymm4,          %ymm6
        vaddpd          %ymm6,          %ymm8,          %ymm8   # C_[0,0:3] += A[0]*B[0:3]
        vmulpd          %ymm0,          %ymm5,          %ymm7
        vaddpd          %ymm7,          %ymm9,          %ymm9   # C_[0,4:7] += A[0]*B[4:7]

        vmulpd          %ymm1,          %ymm4,          %ymm6
        vaddpd          %ymm6,          %ymm10,         %ymm10  # C_[1,0:3] += A[1]*B[0:3]
        vmulpd          %ymm1,          %ymm5,          %ymm7
        vaddpd          %ymm7,          %ymm11,         %ymm11  # C_[1,4:7] += A[1]*B[4:7]

        vmulpd          %ymm2,          %ymm4,          %ymm6
        vaddpd          %ymm6,          %ymm12,         %ymm12  # C_[2,0:3] += A[2]*B[0:3]
        vmulpd          %ymm2,          %ymm5,          %ymm7
        vaddpd          %ymm7,          %ymm13,         %ymm13  # C_[2,4:7] += A[2]*B[4:7]

        vmulpd          %ymm3,          %ymm4,          %ymm6
        vaddpd          %ymm6,          %ymm14,         %ymm14  # C_[3,0:3] += A[3]*B[0:3]
        vmulpd          %ymm3,          %ymm5,          %ymm7
        vaddpd          %ymm7,          %ymm15,         %ymm15  # C_[3,4:7] += A[3]*B[4:7]


        addq            $4*8,           %rsi            # A += 4*sizeof(double), next column
        addq            $2*4*8,         %rdx            # B += 8*sizeof(double), next row

        decq            %rdi                            # --k

.loop_check:
        testq           %rdi,           %rdi
        jne             .loop                           # if (k!=0)
                                                        #   goto .loop;

        vbroadcastsd                 -8(%rsp),          %ymm6   # load alpha
        vbroadcastsd                -16(%rsp),          %ymm7   # load beta

        vmulpd          %ymm6,          %ymm8,          %ymm8   # C_[0,0:3] *= alpha
        vmulpd          %ymm6,          %ymm9,          %ymm9   # C_[0,4:7] *= alpha
        vmulpd          %ymm6,          %ymm10,         %ymm10  # C_[1,0:3] *= alpha
        vmulpd          %ymm6,          %ymm11,         %ymm11  # C_[1,4:7] *= alpha
        vmulpd          %ymm6,          %ymm12,         %ymm12  # C_[2,0:3] *= alpha
        vmulpd          %ymm6,          %ymm13,         %ymm13  # C_[2,4:7] *= alpha
        vmulpd          %ymm6,          %ymm14,         %ymm14  # C_[3,0:3] *= alpha
        vmulpd          %ymm6,          %ymm15,         %ymm15  # C_[3,4:7] *= alpha

        leaq            (,%r8,8),       %r8             # incRowC *= sizeof(double)
        leaq            (,%r9,8),       %r9             # incColC *= sizeof(double)

        leaq            (,%r9,2),       %r10            # 2*incColC*sizeof(double)
        leaq            (%r10,%r9),     %r11            # 3*incColC*sizeof(double)
        leaq            (%rcx,%r10,2),  %rdx            # C + 4*incColC*sizeof(double)
                                                        # (B is done; %rdx is reused)

        # if (beta==0) goto .beta_zero;

        vxorpd          %ymm0,          %ymm0,          %ymm0
        vucomisd        %xmm0,          %xmm7
        je              .beta_zero

# case: beta != 0

#
#       Update C(0,:)
#
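        # C has arbitrary row/column strides, so each group of four elements
        # is gathered two doubles at a time into an xmm register with
        # vmovlpd (low half) and vmovhpd (high half).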
        vmovlpd         (%rcx),         %xmm0,          %xmm0   # load C[0,0:3]
        vmovhpd         (%rcx,%r9),     %xmm0,          %xmm0
        vmovlpd         (%rcx,%r10),    %xmm1,          %xmm1
        vmovhpd         (%rcx,%r11),    %xmm1,          %xmm1
        vmovlpd         (%rdx),         %xmm2,          %xmm2   # load C[0,4:7]
        vmovhpd         (%rdx,%r9),     %xmm2,          %xmm2
        vmovlpd         (%rdx,%r10),    %xmm3,          %xmm3
        vmovhpd         (%rdx,%r11),    %xmm3,          %xmm3

        vmulpd          %xmm7,          %xmm0,          %xmm0   # scale by beta
        vmulpd          %xmm7,          %xmm1,          %xmm1
        vmulpd          %xmm7,          %xmm2,          %xmm2
        vmulpd          %xmm7,          %xmm3,          %xmm3

        vextractf128    $1,             %ymm8,          %xmm4
        vextractf128    $1,             %ymm9,          %xmm5

        vaddpd          %xmm0,          %xmm8,          %xmm0   # add C_
        vaddpd          %xmm1,          %xmm4,          %xmm1
        vaddpd          %xmm2,          %xmm9,          %xmm2
        vaddpd          %xmm3,          %xmm5,          %xmm3

        vmovlpd         %xmm0,          (%rcx)                  # store C[0,0:3]
        vmovhpd         %xmm0,          (%rcx,%r9)
        vmovlpd         %xmm1,          (%rcx,%r10)
        vmovhpd         %xmm1,          (%rcx,%r11)
        vmovlpd         %xmm2,          (%rdx)                  # store C[0,4:7]
        vmovhpd         %xmm2,          (%rdx,%r9)
        vmovlpd         %xmm3,          (%rdx,%r10)
        vmovhpd         %xmm3,          (%rdx,%r11)

#
#       Update C(1,:)
#
        addq            %r8,            %rcx
        addq            %r8,            %rdx

        vmovlpd         (%rcx),         %xmm0,          %xmm0   # load C[1,0:3]
        vmovhpd         (%rcx,%r9),     %xmm0,          %xmm0
        vmovlpd         (%rcx,%r10),    %xmm1,          %xmm1
        vmovhpd         (%rcx,%r11),    %xmm1,          %xmm1
        vmovlpd         (%rdx),         %xmm2,          %xmm2   # load C[1,4:7]
        vmovhpd         (%rdx,%r9),     %xmm2,          %xmm2
        vmovlpd         (%rdx,%r10),    %xmm3,          %xmm3
        vmovhpd         (%rdx,%r11),    %xmm3,          %xmm3

        vmulpd          %xmm7,          %xmm0,          %xmm0   # scale by beta
        vmulpd          %xmm7,          %xmm1,          %xmm1
        vmulpd          %xmm7,          %xmm2,          %xmm2
        vmulpd          %xmm7,          %xmm3,          %xmm3

        vextractf128    $1,             %ymm10,         %xmm4
        vextractf128    $1,             %ymm11,         %xmm5

        vaddpd          %xmm0,          %xmm10,         %xmm0   # add C_
        vaddpd          %xmm1,          %xmm4,          %xmm1
        vaddpd          %xmm2,          %xmm11,         %xmm2
        vaddpd          %xmm3,          %xmm5,          %xmm3

        vmovlpd         %xmm0,          (%rcx)                  # store C[1,0:3]
        vmovhpd         %xmm0,          (%rcx,%r9)
        vmovlpd         %xmm1,          (%rcx,%r10)
        vmovhpd         %xmm1,          (%rcx,%r11)
        vmovlpd         %xmm2,          (%rdx)                  # store C[1,4:7]
        vmovhpd         %xmm2,          (%rdx,%r9)
        vmovlpd         %xmm3,          (%rdx,%r10)
        vmovhpd         %xmm3,          (%rdx,%r11)

#
#       Update C(2,:)
#
        addq            %r8,            %rcx
        addq            %r8,            %rdx

        vmovlpd         (%rcx),         %xmm0,          %xmm0   # load C[2,0:3]
        vmovhpd         (%rcx,%r9),     %xmm0,          %xmm0
        vmovlpd         (%rcx,%r10),    %xmm1,          %xmm1
        vmovhpd         (%rcx,%r11),    %xmm1,          %xmm1
        vmovlpd         (%rdx),         %xmm2,          %xmm2   # load C[2,4:7]
        vmovhpd         (%rdx,%r9),     %xmm2,          %xmm2
        vmovlpd         (%rdx,%r10),    %xmm3,          %xmm3
        vmovhpd         (%rdx,%r11),    %xmm3,          %xmm3

        vmulpd          %xmm7,          %xmm0,          %xmm0   # scale by beta
        vmulpd          %xmm7,          %xmm1,          %xmm1
        vmulpd          %xmm7,          %xmm2,          %xmm2
        vmulpd          %xmm7,          %xmm3,          %xmm3

        vextractf128    $1,             %ymm12,         %xmm4
        vextractf128    $1,             %ymm13,         %xmm5

        vaddpd          %xmm0,          %xmm12,         %xmm0   # add C_
        vaddpd          %xmm1,          %xmm4,          %xmm1
        vaddpd          %xmm2,          %xmm13,         %xmm2
        vaddpd          %xmm3,          %xmm5,          %xmm3

        vmovlpd         %xmm0,          (%rcx)                  # store C[2,0:3]
        vmovhpd         %xmm0,          (%rcx,%r9)
        vmovlpd         %xmm1,          (%rcx,%r10)
        vmovhpd         %xmm1,          (%rcx,%r11)
        vmovlpd         %xmm2,          (%rdx)                  # store C[2,4:7]
        vmovhpd         %xmm2,          (%rdx,%r9)
        vmovlpd         %xmm3,          (%rdx,%r10)
        vmovhpd         %xmm3,          (%rdx,%r11)

#
#       Update C(3,:)
#
        addq            %r8,            %rcx
        addq            %r8,            %rdx

        vmovlpd         (%rcx),         %xmm0,          %xmm0   # load C[3,0:3]
        vmovhpd         (%rcx,%r9),     %xmm0,          %xmm0
        vmovlpd         (%rcx,%r10),    %xmm1,          %xmm1
        vmovhpd         (%rcx,%r11),    %xmm1,          %xmm1
        vmovlpd         (%rdx),         %xmm2,          %xmm2   # load C[3,4:7]
        vmovhpd         (%rdx,%r9),     %xmm2,          %xmm2
        vmovlpd         (%rdx,%r10),    %xmm3,          %xmm3
        vmovhpd         (%rdx,%r11),    %xmm3,          %xmm3

        vmulpd          %xmm7,          %xmm0,          %xmm0   # scale by beta
        vmulpd          %xmm7,          %xmm1,          %xmm1
        vmulpd          %xmm7,          %xmm2,          %xmm2
        vmulpd          %xmm7,          %xmm3,          %xmm3

        vextractf128    $1,             %ymm14,         %xmm4
        vextractf128    $1,             %ymm15,         %xmm5

        vaddpd          %xmm0,          %xmm14,         %xmm0   # add C_
        vaddpd          %xmm1,          %xmm4,          %xmm1
        vaddpd          %xmm2,          %xmm15,         %xmm2
        vaddpd          %xmm3,          %xmm5,          %xmm3

        vmovlpd         %xmm0,          (%rcx)                  # store C[3,0:3]
        vmovhpd         %xmm0,          (%rcx,%r9)
        vmovlpd         %xmm1,          (%rcx,%r10)
        vmovhpd         %xmm1,          (%rcx,%r11)
        vmovlpd         %xmm2,          (%rdx)                  # store C[3,4:7]
        vmovhpd         %xmm2,          (%rdx,%r9)
        vmovlpd         %xmm3,          (%rdx,%r10)
        vmovhpd         %xmm3,          (%rdx,%r11)

        jmp             .done

.beta_zero:

# case: beta == 0
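#
#       With beta == 0 the old contents of C are never read (the usual BLAS
#       convention); alpha*C_ is simply stored.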

#
#       Update C(0,:)
#
        vextractf128    $1,             %ymm8,          %xmm4
        vextractf128    $1,             %ymm9,          %xmm5

        vmovlpd         %xmm8,          (%rcx)                  # store C[0,0:3]
        vmovhpd         %xmm8,          (%rcx,%r9)
        vmovlpd         %xmm4,          (%rcx,%r10)
        vmovhpd         %xmm4,          (%rcx,%r11)
        vmovlpd         %xmm9,          (%rdx)                  # store C[0,4:7]
        vmovhpd         %xmm9,          (%rdx,%r9)
        vmovlpd         %xmm5,          (%rdx,%r10)
        vmovhpd         %xmm5,          (%rdx,%r11)

#
#       Update C(1,:)
#
        addq            %r8,            %rcx
        addq            %r8,            %rdx

        vextractf128    $1,             %ymm10,         %xmm4
        vextractf128    $1,             %ymm11,         %xmm5

        vmovlpd         %xmm10,         (%rcx)                  # store C[1,0:3]
        vmovhpd         %xmm10,         (%rcx,%r9)
        vmovlpd         %xmm4,          (%rcx,%r10)
        vmovhpd         %xmm4,          (%rcx,%r11)
        vmovlpd         %xmm11,         (%rdx)                  # store C[1,4:7]
        vmovhpd         %xmm11,         (%rdx,%r9)
        vmovlpd         %xmm5,          (%rdx,%r10)
        vmovhpd         %xmm5,          (%rdx,%r11)

#
#       Update C(2,:)
#
        addq            %r8,            %rcx
        addq            %r8,            %rdx

        vextractf128    $1,             %ymm12,         %xmm4
        vextractf128    $1,             %ymm13,         %xmm5

        vmovlpd         %xmm12,         (%rcx)                  # store C[2,0:3]
        vmovhpd         %xmm12,         (%rcx,%r9)
        vmovlpd         %xmm4,          (%rcx,%r10)
        vmovhpd         %xmm4,          (%rcx,%r11)
        vmovlpd         %xmm13,         (%rdx)                  # store C[2,4:7]
        vmovhpd         %xmm13,         (%rdx,%r9)
        vmovlpd         %xmm5,          (%rdx,%r10)
        vmovhpd         %xmm5,          (%rdx,%r11)

#
#       Update C(3,:)
#
        addq            %r8,            %rcx
        addq            %r8,            %rdx

        vextractf128    $1,             %ymm14,         %xmm4
        vextractf128    $1,             %ymm15,         %xmm5

        vmovlpd         %xmm14,         (%rcx)                  # store C[3,0:3]
        vmovhpd         %xmm14,         (%rcx,%r9)
        vmovlpd         %xmm4,          (%rcx,%r10)
        vmovhpd         %xmm4,          (%rcx,%r11)
        vmovlpd         %xmm15,         (%rdx)                  # store C[3,4:7]
        vmovhpd         %xmm15,         (%rdx,%r9)
        vmovlpd         %xmm5,          (%rdx,%r10)
        vmovhpd         %xmm5,          (%rdx,%r11)


.done:
        movq        -24(%rsp),      %rbx        # restore %rbx
        movq        -32(%rsp),      %rbp        # restore %rbp

        retq
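
For testing, the kernel can be driven directly from a small C program. The following is only a sketch under explicit assumptions: the A panel holds k slices of 4 doubles (one column each), the B panel holds k slices of 8 doubles (one row each) and must be 32-byte aligned because the kernel loads it with vmovapd, and C is stored row-major. The alignment attribute is GCC/Clang specific, and none of the names below are part of ulmBLAS.

    #include <stdint.h>
    #include <stdio.h>

    void
    dgemm_micro_avx_4x8(uint64_t k, double alpha, double *A, double *B,
                        double beta, double *C, int64_t incRowC, int64_t incColC);

    int
    main()
    {
        enum { MR = 4, NR = 8, K = 2 };             /* illustrative sizes */

        /* A panel: for each l, the column A(0:3, l). */
        double A[K * MR] = { 1, 2, 3, 4,            /* A(0:3, 0) */
                             5, 6, 7, 8 };          /* A(0:3, 1) */

        /* B panel: for each l, the row B(l, 0:7); 32-byte aligned for vmovapd. */
        double B[K * NR] __attribute__((aligned(32))) =
            { 1, 1, 1, 1, 1, 1, 1, 1,               /* B(0, 0:7) */
              2, 2, 2, 2, 2, 2, 2, 2 };             /* B(1, 0:7) */

        /* Row-major C: incRowC = NR, incColC = 1.  With beta == 0 the
           previous contents of C are never read. */
        double C[MR * NR];

        dgemm_micro_avx_4x8(K, 1.0, A, B, 0.0, C, NR, 1);

        for (int i = 0; i < MR; ++i) {
            for (int j = 0; j < NR; ++j) {
                printf("%6.1lf ", C[i * NR + j]);
            }
            printf("\n");
        }
        return 0;
    }

Each entry should come out as C(i,j) = A(i,0)*B(0,j) + A(i,1)*B(1,j); for example, C(0,0) = 1*1 + 5*2 = 11.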