Skip to content

Commit

Permalink
Add inner loop unrolling for f32 GEMM on aarch64
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707440292
  • Loading branch information
alankelly authored and xnnpack-bot committed Jan 10, 2025
1 parent 43578af commit 68cb8e2
Show file tree
Hide file tree
Showing 138 changed files with 12,628 additions and 782 deletions.
435 changes: 365 additions & 70 deletions bench/f32-gemm-minmax.cc

Large diffs are not rendered by default.

411 changes: 323 additions & 88 deletions bench/qd8-f32-qc8w-gemm.cc

Large diffs are not rendered by default.

55 changes: 50 additions & 5 deletions cmake/gen/aarch64_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -123,38 +123,67 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS
src/f32-dwconv/f32-dwconv-9p4c-minmax-asm-aarch64-neonfma.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-acc2-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-acc4-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-1x12-minmax-asm-aarch64-neonfma-cortex-a53.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S
src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75.S
src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-4x12-minmax-asm-aarch64-neonfma-cortex-a53.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-cortex-a75.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-7x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-7x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-7x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-8x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-8x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-8x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S
src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-goi-4x8-minmax-asm-aarch64-neonfma-ld128.S
Expand Down Expand Up @@ -239,13 +268,29 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld128.S
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld64.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld128.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld64.S
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S
Expand Down
143 changes: 105 additions & 38 deletions gemm_compiler/aarch64_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

class Aarch64(base_architecture.BaseArchitecture):

def __init__(self):
self.decrement = 4
self.unroll_factor = 1

def astride_register(self):
return 'x4'

Expand All @@ -24,7 +28,15 @@ def cm_stride_register(self):
return 'x7'

def am_registers(self):
return [self.a_ptr_register()] + ['x9', 'x10', 'x11', 'x12', 'x21', 'x22']
return [self.a_ptr_register()] + [
'x9',
'x10',
'x11',
'x12',
'x21',
'x22',
'x25',
]

def a_ptr_register(self):
return 'x3'
Expand All @@ -33,7 +45,15 @@ def c_ptr_register(self):
return 'x6'

def cm_registers(self):
return [self.c_ptr_register()] + ['x13', 'x14', 'x15', 'x19', 'x23', 'x24']
return [self.c_ptr_register()] + [
'x13',
'x14',
'x15',
'x19',
'x23',
'x24',
'x26',
]

def w_ptr_register(self):
return 'x5'
Expand Down Expand Up @@ -102,7 +122,7 @@ def register_map_dword(self, reg):
return map[reg]

def function_name(self, M, N, isa):
return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_aarch64_{isa}_lane\n'
return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_aarch64_{isa}_ld32\n'

def quantization_params(self):
return ''
Expand All @@ -113,15 +133,16 @@ def header(self, M, N, prefix, isa):
HEADER += 'BEGIN_FUNCTION ' + self.function_name(M, N, isa)
HEADER += """
# Free up GP registers.
stp x19, x20, [sp, -48]
stp x21, x22, [sp, -32]
stp x23, x24, [sp, -16]
stp x19, x20, [sp, -64]
stp x21, x22, [sp, -48]
stp x23, x24, [sp, -32]
stp x25, x26, [sp, -16]
# Preserve callee saved q8-q15 registers.
stp q8, q9, [sp, -176]
stp q10, q11, [sp, -144]
stp q12, q13, [sp, -112]
stp q14, q15, [sp, -80]
stp d8, d9, [sp, -128]
stp d10, d11, [sp, -112]
stp d12, d13, [sp, -96]
stp d14, d15, [sp, -80]
# Load params.
ldr x13, [sp, 8]
Expand All @@ -137,26 +158,11 @@ def jump_to_label(self, label):
def read_a_registers(self, M):
return ''

def inner_loop(self, M, N):
def do_loop(self, M, N, i):
N_COUNT = N // self.n_step()
asm_string = '\ninner_loop:\n'
if 'before' in self.input_asm():
asm_string += self.input_asm()['before']
for mr in range(0, M):
for l in self.input_asm()['loop']:
asm_string += l.format(
AM_ptr=self.am_registers()[mr],
AM=self.a_registers(mr),
a_offset=self.k_register(),
)
if 'after' in self.input_asm():
asm_string += self.input_asm()['after']

# weights
if 'before' in self.weights_asm():
asm_string += self.weights_asm()['before']
asm_string = ''
for l in self.weights_asm()['loop_2']:
for nr in range(0, N_COUNT, 2):
for nr in range(0, N_COUNT - 1, 2):
asm_string += l.format(
W_ptr=self.w_ptr_register(),
W=self.w_registers()[nr],
Expand Down Expand Up @@ -184,9 +190,69 @@ def inner_loop(self, M, N):
W=self.w_registers()[nr],
A=self.a_registers(mr),
ACC=self.acc_registers()[M * nr + mr],
POS=i,
)
return asm_string

def inner_loop(self, M, N):
asm_string = ''
if self.unroll_factor > 1:
DECREMENT = self.unroll_factor * 4
k_register = self.k_register()
asm_string += f'\n# Are there at least {DECREMENT} bytes?\n'
asm_string += f'cmp {k_register}, {DECREMENT}\n'
asm_string += f'blt inner_loop_tail\n'
asm_string += f'sub {k_register}, {k_register}, {DECREMENT}\n'

asm_string += '\ninner_loop:\n'
decrement = 4 * self.unroll_factor
if 'before' in self.input_asm():
asm_string += self.input_asm()['before']
for mr in range(0, M):
for l in self.input_asm()['loop']:
asm_string += l.format(
AM_ptr=self.am_registers()[mr],
AM=self.a_registers(mr),
a_offset=self.k_register(),
)
if 'after' in self.input_asm():
asm_string += self.input_asm()['after']

# weights
if 'before' in self.weights_asm():
asm_string += self.weights_asm()['before']
inner_loop_label = 'inner_loop'
if self.unroll_factor > 1:
for u in range(self.unroll_factor):
asm_string += self.do_loop(M, N, u)
# loop counter
asm_string += self.cmp_k_and_jump_if_less(
label=inner_loop_label, decrement=decrement, cond='bhs'
)

asm_string += f"""
add x20, x20, {decrement}
cmp x20, 4
blt inner_loop_end
\ninner_loop_tail:\n"""
inner_loop_label = 'inner_loop_tail'

for mr in range(0, M):
for l in self.base_input_asm()['loop']:
asm_string += l.format(
AM_ptr=self.am_registers()[mr],
AM=self.a_registers(mr),
a_offset=self.k_register(),
)
asm_string += self.do_loop(M, N, 0)
# loop counter
asm_string += self.cmp_k_and_jump_if_less(
label=inner_loop_label, decrement=4, cond='bne'
)
asm_string += '\n'

return asm_string

def outer_loop_prepare(self, M, N):
return ''

Expand Down Expand Up @@ -280,25 +346,26 @@ def initialize_k_register(self, reg):
kc_register = self.kc_register()
return f'mov {reg}, {kc_register}\n'

def cmp_k_and_jump_if_less(self, label):
def cmp_k_and_jump_if_less(self, label, decrement, cond):
kc_register = self.kc_register()
k_register = self.k_register()
return f"""subs {k_register}, {k_register}, 4
bne {label}\n"""
return f"""subs {k_register}, {k_register}, {decrement}
{cond} {label}\n"""

def epilogue(self, M, N, isa):
restore_stack = """
return:
# Restore the callee saved GP registers.
ldp x19, x20, [sp, -48]
ldp x21, x22, [sp, -32]
ldp x23, x24, [sp, -16]
ldp x19, x20, [sp, -64]
ldp x21, x22, [sp, -48]
ldp x23, x24, [sp, -32]
ldp x25, x26, [sp, -16]
# Restore callee saved q8-q15 registers.
ldp q8, q9, [sp, -176]
ldp q10, q11, [sp, -144]
ldp q12, q13, [sp, -112]
ldp q14, q15, [sp, -80]
ldp d8, d9, [sp, -128]
ldp d10, d11, [sp, -112]
ldp d12, d13, [sp, -96]
ldp d14, d15, [sp, -80]
ret
END_FUNCTION {function_name}""".format(
M=M, N=N, function_name=isa.function_name(M, N, isa.isa())
Expand Down
6 changes: 2 additions & 4 deletions gemm_compiler/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,11 @@ def generate_gemm_microkernel(
# inner loop
asm_string += isa.inner_loop(M, N)

# loop counter
asm_string += isa.cmp_k_and_jump_if_less(label='inner_loop')

asm_string += 'inner_loop_end:\n'
asm_string += isa.dequantize(M=M, N=num_horizontal_registers, W=w_ptr_reg)

# min/max clamping
asm_string += '# Min/max clamping..\n'
asm_string += '# Min/max clamping.\n'
for nr in range(0, num_horizontal_registers):
for mr in range(0, M):
asm_string += isa.clamp_min(
Expand Down
22 changes: 18 additions & 4 deletions gemm_compiler/generate_f32_gemm_microkernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,28 @@ def generate_f32_gemm_microkernels():
),
)

for nr in range(8, 17, 8):
for unroll in {1, 2, 4}:
decrement = 32 * unroll
for mr in range(1, 6):
generate.generate_gemm_microkernel(
M=mr,
N=nr,
isa=neonfma_template.NeonFma(),
N=16,
isa=neonfma_template.NeonFmaUnolled(unroll),
output_file=os.path.join(
output_base,
f'f32-gemm-{mr}x16-minmax-asm-aarch64-neonfma-ld{decrement}.S',
),
)

for unroll in {1, 2, 4}:
decrement = 32 * unroll
for mr in range(1, 9):
generate.generate_gemm_microkernel(
M=mr,
N=8,
isa=neonfma_template.NeonFmaUnolled(unroll),
output_file=os.path.join(
output_base,
f'f32-gemm-{mr}x{nr}-minmax-asm-aarch64-neonfma-ld32.S',
f'f32-gemm-{mr}x8-minmax-asm-aarch64-neonfma-ld{decrement}-2.S',
),
)
Loading

0 comments on commit 68cb8e2

Please sign in to comment.