GEMM Micro Kernel in Assembler Code
Exercise
Use the GEMM micro kernel given below, which is written entirely in assembly code. Because the C compiler no longer sees a definition of the function, you need to add a forward declaration for it in ulmblas.c.
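The declaration has to match the calling convention documented at the top of the assembly listing. A minimal sketch of what the forward declaration in ulmblas.c could look like (the stdint.h include may already be present in your file):

#include <stdint.h>

//  Micro kernel implemented in assembly; the definition is in the
//  listing below.
void
dgemm_micro_avx_4x8(uint64_t k,
                    double   alpha,
                    double   *A,
                    double   *B,
                    double   beta,
                    double   *C,
                    int64_t  incRowC,
                    int64_t  incColC);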
#
#   void
#   dgemm_micro_avx_4x8(uint64_t   k,
#                       double     alpha,
#                       double     *A,
#                       double     *B,
#                       double     beta,
#                       double     *C,
#                       int64_t    incRowC,
#                       int64_t    incColC);
#
#   Arguments are passed in registers as follows:
#
#   %rdi  : uint64_t k
#   %xmm0 : double   alpha
#   %rsi  : double   *A
#   %rdx  : double   *B
#   %xmm1 : double   beta
#   %rcx  : double   *C
#   %r8   : int64_t  incRowC
#   %r9   : int64_t  incColC
#
    .globl dgemm_micro_avx_4x8
dgemm_micro_avx_4x8:

    vmovlpd %xmm0,  -8(%rsp)            # push alpha
    vmovlpd %xmm1, -16(%rsp)            # push beta
    movq    %rbx,  -24(%rsp)            # push %rbx
    movq    %rbp,  -32(%rsp)            # push %rbp

#   Initialize C_
    vxorpd  %ymm8,  %ymm8,  %ymm8       # C_[0,0:3]
    vxorpd  %ymm9,  %ymm9,  %ymm9       # C_[0,4:7]
    vxorpd  %ymm10, %ymm10, %ymm10      # C_[1,0:3]
    vxorpd  %ymm11, %ymm11, %ymm11      # C_[1,4:7]
    vxorpd  %ymm12, %ymm12, %ymm12      # C_[2,0:3]
    vxorpd  %ymm13, %ymm13, %ymm13      # C_[2,4:7]
    vxorpd  %ymm14, %ymm14, %ymm14      # C_[3,0:3]
    vxorpd  %ymm15, %ymm15, %ymm15      # C_[3,4:7]

    jmp     .loop_check

.loop:
    vbroadcastsd    0 * 8(%rsi), %ymm0  # A[0]
    vbroadcastsd    1 * 8(%rsi), %ymm1  # A[1]
    vbroadcastsd    2 * 8(%rsi), %ymm2  # A[2]
    vbroadcastsd    3 * 8(%rsi), %ymm3  # A[3]

    vmovapd         0 * 32(%rdx), %ymm4 # B[0:3]
    vmovapd         1 * 32(%rdx), %ymm5 # B[4:7]

    vmulpd  %ymm0, %ymm4, %ymm6
    vaddpd  %ymm6, %ymm8, %ymm8         # C_[0,0:3] += A[0]*B[0:3]
    vmulpd  %ymm0, %ymm5, %ymm7
    vaddpd  %ymm7, %ymm9, %ymm9         # C_[0,4:7] += A[0]*B[4:7]

    vmulpd  %ymm1, %ymm4, %ymm6
    vaddpd  %ymm6, %ymm10, %ymm10       # C_[1,0:3] += A[1]*B[0:3]
    vmulpd  %ymm1, %ymm5, %ymm7
    vaddpd  %ymm7, %ymm11, %ymm11       # C_[1,4:7] += A[1]*B[4:7]

    vmulpd  %ymm2, %ymm4, %ymm6
    vaddpd  %ymm6, %ymm12, %ymm12       # C_[2,0:3] += A[2]*B[0:3]
    vmulpd  %ymm2, %ymm5, %ymm7
    vaddpd  %ymm7, %ymm13, %ymm13       # C_[2,4:7] += A[2]*B[4:7]

    vmulpd  %ymm3, %ymm4, %ymm6
    vaddpd  %ymm6, %ymm14, %ymm14       # C_[3,0:3] += A[3]*B[0:3]
    vmulpd  %ymm3, %ymm5, %ymm7
    vaddpd  %ymm7, %ymm15, %ymm15       # C_[3,4:7] += A[3]*B[4:7]

    addq    $4*8,   %rsi                # A += 4*sizeof(double)
    addq    $2*4*8, %rdx                # B += 2*4*sizeof(double)

    decq    %rdi                        # --k

.loop_check:
    testq   %rdi, %rdi
    jne     .loop                       # if (k!=0) goto .loop;

    vbroadcastsd     -8(%rsp), %ymm6    # load alpha
    vbroadcastsd    -16(%rsp), %ymm7    # load beta

    vmulpd  %ymm6, %ymm8,  %ymm8        # C_[0,0:3] *= alpha
    vmulpd  %ymm6, %ymm9,  %ymm9        # C_[0,4:7] *= alpha
    vmulpd  %ymm6, %ymm10, %ymm10       # C_[1,0:3] *= alpha
    vmulpd  %ymm6, %ymm11, %ymm11       # C_[1,4:7] *= alpha
    vmulpd  %ymm6, %ymm12, %ymm12       # C_[2,0:3] *= alpha
    vmulpd  %ymm6, %ymm13, %ymm13       # C_[2,4:7] *= alpha
    vmulpd  %ymm6, %ymm14, %ymm14       # C_[3,0:3] *= alpha
    vmulpd  %ymm6, %ymm15, %ymm15       # C_[3,4:7] *= alpha

    leaq    (,%r8,8), %r8               # load incRowC*sizeof(double)
    leaq    (,%r9,8), %r9               # load incColC*sizeof(double)
    leaq    (,%r9,2), %r10              # load 2*incColC*sizeof(double)
    leaq    (%r10,%r9), %r11            # load 3*incColC*sizeof(double)
    leaq    (%rcx,%r10,2), %rdx         # load C + 4*incColC*sizeof(double)

#   if (beta==0) goto .beta_zero;
    vxorpd      %ymm0, %ymm0, %ymm0
    vucomisd    %xmm0, %xmm7
    je          .beta_zero

#   case: beta != 0

#
#   Update C(0,:)
#
    vmovlpd (%rcx),      %xmm0, %xmm0   # load C[0,0:3]
    vmovhpd (%rcx,%r9),  %xmm0, %xmm0
    vmovlpd (%rcx,%r10), %xmm1, %xmm1
    vmovhpd (%rcx,%r11), %xmm1, %xmm1
    vmovlpd (%rdx),      %xmm2, %xmm2   # load C[0,4:7]
    vmovhpd (%rdx,%r9),  %xmm2, %xmm2
    vmovlpd (%rdx,%r10), %xmm3, %xmm3
    vmovhpd (%rdx,%r11), %xmm3, %xmm3

    vmulpd  %xmm7, %xmm0, %xmm0         # scale by beta
    vmulpd  %xmm7, %xmm1, %xmm1
    vmulpd  %xmm7, %xmm2, %xmm2
    vmulpd  %xmm7, %xmm3, %xmm3

    vextractf128 $1, %ymm8, %xmm4
    vextractf128 $1, %ymm9, %xmm5

    vaddpd  %xmm0, %xmm8, %xmm0         # add C_
    vaddpd  %xmm1, %xmm4, %xmm1
    vaddpd  %xmm2, %xmm9, %xmm2
    vaddpd  %xmm3, %xmm5, %xmm3

    vmovlpd %xmm0, (%rcx)               # store C[0,0:3]
    vmovhpd %xmm0, (%rcx,%r9)
    vmovlpd %xmm1, (%rcx,%r10)
    vmovhpd %xmm1, (%rcx,%r11)
    vmovlpd %xmm2, (%rdx)               # store C[0,4:7]
    vmovhpd %xmm2, (%rdx,%r9)
    vmovlpd %xmm3, (%rdx,%r10)
    vmovhpd %xmm3, (%rdx,%r11)

#
#   Update C(1,:)
#
    addq    %r8, %rcx
    addq    %r8, %rdx

    vmovlpd (%rcx),      %xmm0, %xmm0   # load C[1,0:3]
    vmovhpd (%rcx,%r9),  %xmm0, %xmm0
    vmovlpd (%rcx,%r10), %xmm1, %xmm1
    vmovhpd (%rcx,%r11), %xmm1, %xmm1
    vmovlpd (%rdx),      %xmm2, %xmm2   # load C[1,4:7]
    vmovhpd (%rdx,%r9),  %xmm2, %xmm2
    vmovlpd (%rdx,%r10), %xmm3, %xmm3
    vmovhpd (%rdx,%r11), %xmm3, %xmm3

    vmulpd  %xmm7, %xmm0, %xmm0         # scale by beta
    vmulpd  %xmm7, %xmm1, %xmm1
    vmulpd  %xmm7, %xmm2, %xmm2
    vmulpd  %xmm7, %xmm3, %xmm3

    vextractf128 $1, %ymm10, %xmm4
    vextractf128 $1, %ymm11, %xmm5

    vaddpd  %xmm0, %xmm10, %xmm0        # add C_
    vaddpd  %xmm1, %xmm4,  %xmm1
    vaddpd  %xmm2, %xmm11, %xmm2
    vaddpd  %xmm3, %xmm5,  %xmm3

    vmovlpd %xmm0, (%rcx)               # store C[1,0:3]
    vmovhpd %xmm0, (%rcx,%r9)
    vmovlpd %xmm1, (%rcx,%r10)
    vmovhpd %xmm1, (%rcx,%r11)
    vmovlpd %xmm2, (%rdx)               # store C[1,4:7]
    vmovhpd %xmm2, (%rdx,%r9)
    vmovlpd %xmm3, (%rdx,%r10)
    vmovhpd %xmm3, (%rdx,%r11)

#
#   Update C(2,:)
#
    addq    %r8, %rcx
    addq    %r8, %rdx

    vmovlpd (%rcx),      %xmm0, %xmm0   # load C[2,0:3]
    vmovhpd (%rcx,%r9),  %xmm0, %xmm0
    vmovlpd (%rcx,%r10), %xmm1, %xmm1
    vmovhpd (%rcx,%r11), %xmm1, %xmm1
    vmovlpd (%rdx),      %xmm2, %xmm2   # load C[2,4:7]
    vmovhpd (%rdx,%r9),  %xmm2, %xmm2
    vmovlpd (%rdx,%r10), %xmm3, %xmm3
    vmovhpd (%rdx,%r11), %xmm3, %xmm3

    vmulpd  %xmm7, %xmm0, %xmm0         # scale by beta
    vmulpd  %xmm7, %xmm1, %xmm1
    vmulpd  %xmm7, %xmm2, %xmm2
    vmulpd  %xmm7, %xmm3, %xmm3

    vextractf128 $1, %ymm12, %xmm4
    vextractf128 $1, %ymm13, %xmm5

    vaddpd  %xmm0, %xmm12, %xmm0        # add C_
    vaddpd  %xmm1, %xmm4,  %xmm1
    vaddpd  %xmm2, %xmm13, %xmm2
    vaddpd  %xmm3, %xmm5,  %xmm3

    vmovlpd %xmm0, (%rcx)               # store C[2,0:3]
    vmovhpd %xmm0, (%rcx,%r9)
    vmovlpd %xmm1, (%rcx,%r10)
    vmovhpd %xmm1, (%rcx,%r11)
    vmovlpd %xmm2, (%rdx)               # store C[2,4:7]
    vmovhpd %xmm2, (%rdx,%r9)
    vmovlpd %xmm3, (%rdx,%r10)
    vmovhpd %xmm3, (%rdx,%r11)

#
#   Update C(3,:)
#
    addq    %r8, %rcx
    addq    %r8, %rdx

    vmovlpd (%rcx),      %xmm0, %xmm0   # load C[3,0:3]
    vmovhpd (%rcx,%r9),  %xmm0, %xmm0
    vmovlpd (%rcx,%r10), %xmm1, %xmm1
    vmovhpd (%rcx,%r11), %xmm1, %xmm1
    vmovlpd (%rdx),      %xmm2, %xmm2   # load C[3,4:7]
    vmovhpd (%rdx,%r9),  %xmm2, %xmm2
    vmovlpd (%rdx,%r10), %xmm3, %xmm3
    vmovhpd (%rdx,%r11), %xmm3, %xmm3

    vmulpd  %xmm7, %xmm0, %xmm0         # scale by beta
    vmulpd  %xmm7, %xmm1, %xmm1
    vmulpd  %xmm7, %xmm2, %xmm2
    vmulpd  %xmm7, %xmm3, %xmm3

    vextractf128 $1, %ymm14, %xmm4
    vextractf128 $1, %ymm15, %xmm5

    vaddpd  %xmm0, %xmm14, %xmm0        # add C_
    vaddpd  %xmm1, %xmm4,  %xmm1
    vaddpd  %xmm2, %xmm15, %xmm2
    vaddpd  %xmm3, %xmm5,  %xmm3

    vmovlpd %xmm0, (%rcx)               # store C[3,0:3]
    vmovhpd %xmm0, (%rcx,%r9)
    vmovlpd %xmm1, (%rcx,%r10)
    vmovhpd %xmm1, (%rcx,%r11)
    vmovlpd %xmm2, (%rdx)               # store C[3,4:7]
    vmovhpd %xmm2, (%rdx,%r9)
    vmovlpd %xmm3, (%rdx,%r10)
    vmovhpd %xmm3, (%rdx,%r11)

    jmp     .done

.beta_zero:
#   case: beta == 0

#
#   Update C(0,:)
#
    vextractf128 $1, %ymm8, %xmm4
    vextractf128 $1, %ymm9, %xmm5

    vmovlpd %xmm8, (%rcx)               # store C[0,0:3]
    vmovhpd %xmm8, (%rcx,%r9)
    vmovlpd %xmm4, (%rcx,%r10)
    vmovhpd %xmm4, (%rcx,%r11)
    vmovlpd %xmm9, (%rdx)               # store C[0,4:7]
    vmovhpd %xmm9, (%rdx,%r9)
    vmovlpd %xmm5, (%rdx,%r10)
    vmovhpd %xmm5, (%rdx,%r11)

#
#   Update C(1,:)
#
    addq    %r8, %rcx
    addq    %r8, %rdx

    vextractf128 $1, %ymm10, %xmm4
    vextractf128 $1, %ymm11, %xmm5

    vmovlpd %xmm10, (%rcx)              # store C[1,0:3]
    vmovhpd %xmm10, (%rcx,%r9)
    vmovlpd %xmm4,  (%rcx,%r10)
    vmovhpd %xmm4,  (%rcx,%r11)
    vmovlpd %xmm11, (%rdx)              # store C[1,4:7]
    vmovhpd %xmm11, (%rdx,%r9)
    vmovlpd %xmm5,  (%rdx,%r10)
    vmovhpd %xmm5,  (%rdx,%r11)

#
#   Update C(2,:)
#
    addq    %r8, %rcx
    addq    %r8, %rdx

    vextractf128 $1, %ymm12, %xmm4
    vextractf128 $1, %ymm13, %xmm5

    vmovlpd %xmm12, (%rcx)              # store C[2,0:3]
    vmovhpd %xmm12, (%rcx,%r9)
    vmovlpd %xmm4,  (%rcx,%r10)
    vmovhpd %xmm4,  (%rcx,%r11)
    vmovlpd %xmm13, (%rdx)              # store C[2,4:7]
    vmovhpd %xmm13, (%rdx,%r9)
    vmovlpd %xmm5,  (%rdx,%r10)
    vmovhpd %xmm5,  (%rdx,%r11)

#
#   Update C(3,:)
#
    addq    %r8, %rcx
    addq    %r8, %rdx

    vextractf128 $1, %ymm14, %xmm4
    vextractf128 $1, %ymm15, %xmm5

    vmovlpd %xmm14, (%rcx)              # store C[3,0:3]
    vmovhpd %xmm14, (%rcx,%r9)
    vmovlpd %xmm4,  (%rcx,%r10)
    vmovhpd %xmm4,  (%rcx,%r11)
    vmovlpd %xmm15, (%rdx)              # store C[3,4:7]
    vmovhpd %xmm15, (%rdx,%r9)
    vmovlpd %xmm5,  (%rdx,%r10)
    vmovhpd %xmm5,  (%rdx,%r11)

.done:
    movq    -24(%rsp), %rbx             # pop %rbx
    movq    -32(%rsp), %rbp             # pop %rbp

    retq
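If you want to try the kernel in isolation before integrating it into ulmblas.c, a small test driver can help. The sketch below is not part of the exercise; the file names, the k=1 test case, and the row-major layout of C are assumptions made purely for illustration. Note that B is read with vmovapd, so the packed B buffer must be 32-byte aligned, whereas A and C are accessed with instructions that carry no alignment requirement.

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

void
dgemm_micro_avx_4x8(uint64_t k, double alpha, double *A, double *B,
                    double beta, double *C,
                    int64_t incRowC, int64_t incColC);

int
main(void)
{
    //  single rank-1 update (k=1) with alpha=1, beta=0:
    //  afterwards C[i][j] should equal A[i]*B[j]
    double A[4] = {1.0, 2.0, 3.0, 4.0};
    double C[4*8];

    //  B is loaded with vmovapd and therefore must be 32-byte aligned
    double *B = aligned_alloc(32, 8*sizeof(double));
    for (int j=0; j<8; ++j) {
        B[j] = j+1;
    }

    //  C is stored row major: incRowC=8, incColC=1
    dgemm_micro_avx_4x8(1, 1.0, A, B, 0.0, C, 8, 1);

    for (int i=0; i<4; ++i) {
        for (int j=0; j<8; ++j) {
            printf("%6.1lf", C[i*8+j]);
        }
        printf("\n");
    }
    free(B);
    return 0;
}

Assuming the assembly is stored in dgemm_micro_avx_4x8.s and the driver in test_micro.c, both can be assembled, compiled, and linked in one step with gcc -std=c11 test_micro.c dgemm_micro_avx_4x8.s.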