From dd0ed28baeeb5a6421863566ffc79f4850b4e541 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Sun, 1 Dec 2024 23:12:04 -0500 Subject: [PATCH 1/5] [Packing Refactor] Move all Blockwise Packing to pack_weights_and_bias --- src/configs/gemm-config.c | 9 +- src/operators/batch-matrix-multiply-nc.c | 10 +- src/operators/convolution-nhwc.c | 1 + src/operators/fully-connected-nc.c | 98 ++------------- src/reference/packing.cc | 152 ++++++++++++++++++++--- src/xnnpack/config-types.h | 2 - src/xnnpack/microfnptr.h | 2 + src/xnnpack/pack.h | 42 ++++++- 8 files changed, 199 insertions(+), 117 deletions(-) diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index d1e754751d3..d5236d1f5d4 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -1484,7 +1484,8 @@ static void init_qd8_f16_qc4w_gemm_config(void) { } static void init_qd8_f16_qb4w_gemm_config(void) { - qd8_f16_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; + qd8_f16_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases; + qd8_f16_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases; #if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -1744,7 +1745,8 @@ static void init_qp8_f32_qb4w_gemm_config(void) { } static void init_qdu8_f32_qb4w_gemm_config(void) { - qdu8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; + qdu8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases; + qdu8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases; #if XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); @@ -1777,7 +1779,8 @@ static void init_qdu8_f32_qb4w_gemm_config(void) { } static void init_qd8_f32_qb4w_gemm_config(void) { - qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; + qd8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases; #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c index a18f21ff3a0..d943c8d32ec 100644 --- a/src/operators/batch-matrix-multiply-nc.c +++ b/src/operators/batch-matrix-multiply-nc.c @@ -180,7 +180,9 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights( // Pack the weights. if (gemm_config->pack_weights_and_biases) { gemm_config->pack_weights_and_biases(flags, gemm_config, k, n, - /*groups=*/batch_size_b, k_stride, + /*groups=*/batch_size_b, + /*unused_block_size=*/0, + /*kstride=*/k_stride, /*accumulator_init=*/NULL, /*weights=*/data_b, /*int_extra_data0_fn=*/NULL, @@ -313,7 +315,7 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w( const size_t weights_stride = gemm_config->packed_stride_weights_and_biases ? 
gemm_config->packed_stride_weights_and_biases( - gemm_config, k, k_stride, extra_bytes) + gemm_config, k,/*unused_blocksize=*/0, k_stride, extra_bytes) : (k_stride << XNN_LOG2_SIZEOF_INT8_T) + extra_bytes + sizeof(int32_t); assert(weights_stride == (k_stride << XNN_LOG2_SIZEOF_INT8_T) + @@ -345,7 +347,9 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w( batch_matrix_multiply_op->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS, gemm_config, /*input_channels=*/k, /*output_channels=*/n, - /*groups=*/batch_size_b, k_stride, + /*groups=*/batch_size_b, + /*unused_block_size=*/0, + /*k_stride=*/k_stride, /*accumulator_init=*/NULL, /*weights=*/data_b, /*int_extra_data0_fn=*/ diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index bfc0be24eb7..3f3a22b3720 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -372,6 +372,7 @@ static enum xnn_status create_gemm_or_igemm( gemm_config->pack_weights_and_biases( flags, gemm_config, group_input_channels, group_output_channels, groups, + /*unused_block_size*/0, k_stride, /*accumulator_init=*/bias, /*weights=*/kernel, diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 30c7a9c7f7e..34b20c2ff0a 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -44,7 +44,6 @@ static enum xnn_status create_fully_connected_nc( const void* bias, uint32_t flags, size_t block_size, - size_t extra_bl_bytes, const uint16_t* blockwise_kernel_scale_params, uint32_t log2_input_element_size, uint32_t log2_filter_element_size, @@ -52,7 +51,6 @@ static enum xnn_status create_fully_connected_nc( uint32_t bias_element_size, xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio_w, xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w, - xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi_bl_w, const void* packing_params, int packed_weights_padding_byte, size_t extra_weights_bytes, @@ -155,7 +153,7 @@ static enum xnn_status create_fully_connected_nc( const size_t weights_stride = gemm_config->packed_stride_weights_and_biases ? gemm_config->packed_stride_weights_and_biases( - gemm_config, input_channels, block_wise ? block_size : k_stride, extra_weights_bytes) + gemm_config, input_channels, block_size, k_stride, extra_weights_bytes) : (k_stride << log2_filter_element_size) + bias_element_size + extra_weights_bytes + block_scale_bytes; const size_t packed_weights_size = n_stride * weights_stride; @@ -192,7 +190,8 @@ static enum xnn_status create_fully_connected_nc( gemm_config->pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, /*groups=*/1, - block_wise ? block_size : k_stride, + /*block_wise=*/block_size, + /*kstride=*/k_stride, /*accumulator_init=*/bias, /*weights=*/kernel, /*int_extra_data0_fn=*/(xnn_init_scale_params_fn)init_scale_params, @@ -204,16 +203,6 @@ static enum xnn_status create_fully_connected_nc( /*extra_data1_size=*/init_kernel_scale_params != NULL ? 
sizeof(float) : 0, /*packed_weights_ptr=*/weights_ptr, packing_params); - - if (block_wise && bias != NULL) { - void* weights_start = (void*) ((uintptr_t) weights_ptr + - gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); - weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; - xnn_init_qs8_qc8w_scale_fp32_params( - output_channels, gemm_config->nr, gemm_config->nr, - gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, - bias, weights_start); - } } else { if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { pack_gemm_gio_w( @@ -225,24 +214,13 @@ static enum xnn_status create_fully_connected_nc( gemm_config->nr * extra_weights_bytes, packing_params); } else { - if (block_wise) { - pack_gemm_goi_bl_w( - /*groups=*/1, output_channels, input_channels, - nr, kr, sr, block_size, - kernel, /*bias=*/NULL, /*scale=*/blockwise_kernel_scale_params, - weights_ptr, - gemm_config->nr * extra_bl_bytes, - gemm_config->nr * extra_weights_bytes, - packing_params); - } else { - pack_gemm_goi_w( - /*groups=*/1, output_channels, input_channels, - nr, kr, sr, - kernel, bias, /*scale=*/NULL, - weights_ptr, - gemm_config->nr * extra_weights_bytes, - packing_params); - } + pack_gemm_goi_w( + /*groups=*/1, output_channels, input_channels, + nr, kr, sr, + kernel, bias, /*scale=*/NULL, + weights_ptr, + gemm_config->nr * extra_weights_bytes, + packing_params); } if (kernel_scale_params != NULL) { assert(init_kernel_scale_params != NULL); @@ -267,32 +245,6 @@ static enum xnn_status create_fully_connected_nc( gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, scale_params, weights); } - - if (block_wise) { - // Fill in kernel scale. - void* weights_start = (void*) ((uintptr_t) weights_ptr + - gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); - - const size_t block_stride = /*weights*/block_size / 2 + sizeof(uint16_t); - - xnn_init_blockwise_scale_bf16_params( - output_channels, gemm_config->nr, gemm_config->nr, - gemm_config->nr * weights_stride, - gemm_config->nr * weights_stride, - /*num_blocks=*/num_blocks, - /*block_stride=*/gemm_config->nr * block_stride, - 0, - (const xnn_bfloat16*)blockwise_kernel_scale_params, weights_start); - - // Fill in bias. 
- if (bias != NULL) { - weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; - xnn_init_qs8_qc8w_scale_fp32_params( - output_channels, gemm_config->nr, gemm_config->nr, - gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, - bias, weights_start); - } - } } if (use_weights_cache(fully_connected_op)) { @@ -397,7 +349,6 @@ enum xnn_status xnn_create_fully_connected_nc_f16( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_HALF, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_HALF, @@ -405,7 +356,6 @@ enum xnn_status xnn_create_fully_connected_nc_f16( /*bias_element_size=*/sizeof(uint16_t), pack_gemm_gio_w, pack_gemm_goi_w, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, @@ -494,7 +444,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc4w( input_stride, output_stride, kernel, /*bias=*/NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -502,7 +451,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc4w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -669,7 +617,6 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/block_size, - /*extra_bl_bytes=*/sizeof(uint16_t), /*blockwise_kernel_scale_params=*/kernel_scale, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -677,7 +624,6 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w( /*bias_element_size=*/sizeof(float), /*pack_gemm_gio_w,=*/ NULL, /*pack_gemm_goi_w=*/ NULL, - /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl, /*packing_params=*/&packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -762,7 +708,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qc4w( input_stride, output_stride, kernel, /*bias=*/NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -770,7 +715,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qc4w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -909,7 +853,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( input_channels, output_channels, input_stride, output_stride, kernel, /*bias=*/NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -917,7 +860,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn)gemm_config->pack_gemm_gio, 
(xnn_packw_gemm_goi_ukernel_fn)gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, @@ -1031,7 +973,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/block_size, - /*extra_bl_bytes=*/sizeof(uint16_t), /*blockwise_kernel_scale_params=*/kernel_scale, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -1039,7 +980,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( /*bias_element_size=*/sizeof(float), /*pack_gemm_gio_w,=*/ NULL, /*pack_gemm_goi_w=*/ NULL, - /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, @@ -1156,7 +1096,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qb4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/block_size, - /*extra_bl_bytes=*/sizeof(uint16_t), /*blockwise_kernel_scale_params=*/kernel_scale, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -1164,7 +1103,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qb4w( /*bias_element_size=*/sizeof(float), /*pack_gemm_gio_w,=*/ NULL, /*pack_gemm_goi_w=*/ NULL, - /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -1287,7 +1225,6 @@ enum xnn_status create_fully_connected_nc_qdx8_f32_qc8w( input_stride, output_stride, kernel, NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -1295,7 +1232,6 @@ enum xnn_status create_fully_connected_nc_qdx8_f32_qc8w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -1418,7 +1354,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc8w( input_stride, output_stride, kernel, NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -1426,7 +1361,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc8w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -1574,7 +1508,6 @@ enum xnn_status create_fully_connected_nc_f32( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_FLOAT, @@ -1582,7 +1515,6 @@ enum xnn_status create_fully_connected_nc_f32( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, @@ -1719,7 +1651,6 @@ enum xnn_status 
xnn_create_fully_connected_nc_f32_qc4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT, // Pass 1 byte even though it is half byte, we handle the division via filter_is_nibble == true. @@ -1728,7 +1659,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc4w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) NULL, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -1812,7 +1742,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc8w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -1820,7 +1749,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc8w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -1912,7 +1840,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -1920,7 +1847,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8( /*bias_element_size=*/sizeof(int32_t), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -2019,7 +1945,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -2027,7 +1952,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w( /*bias_element_size=*/sizeof(int32_t), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -2119,7 +2043,6 @@ enum xnn_status xnn_create_fully_connected_nc_qu8( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -2127,7 +2050,6 @@ enum xnn_status xnn_create_fully_connected_nc_qu8( /*bias_element_size=*/sizeof(int32_t), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/kernel_zero_point, /*extra_weights_bytes=*/0, diff --git a/src/reference/packing.cc b/src/reference/packing.cc index 1bd7912db06..c39f2aa412c 100644 --- a/src/reference/packing.cc +++ b/src/reference/packing.cc @@ -19,6 +19,7 @@ #include 
"xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" +#include "xnnpack/microparams-init.h" #include "xnnpack/pack.h" #include "xnnpack/unaligned.h" @@ -1408,8 +1409,8 @@ void pack_weights_and_biases(uint32_t flags, // } size_t xnn_packed_stride_qs8_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size, + size_t k_stride, size_t extra_bytes) { const size_t bias_element_size = sizeof(int32_t); const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; return (k_stride << log2_filter_element_size) + bias_element_size + @@ -1418,7 +1419,7 @@ size_t xnn_packed_stride_qs8_weights_and_biases( void xnn_pack_qs8_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, - size_t input_channels, size_t output_channels, size_t groups, + size_t input_channels, size_t output_channels, size_t groups, size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, @@ -1428,7 +1429,7 @@ void xnn_pack_qs8_weights_and_biases( const size_t extra_bytes = extra_data0_element_size + extra_data1_element_size; const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( - gemm_config, input_channels, k_stride, extra_bytes); + gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, groups, weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_gemm_gio_w, @@ -1439,8 +1440,8 @@ void xnn_pack_qs8_weights_and_biases( } size_t xnn_packed_stride_qs4_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size, + size_t k_stride, size_t extra_bytes) { const size_t bias_element_size = sizeof(int32_t); const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; return (k_stride << log2_filter_element_size) + bias_element_size + @@ -1450,7 +1451,7 @@ size_t xnn_packed_stride_qs4_weights_and_biases( void xnn_pack_qs4_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1458,7 +1459,7 @@ void xnn_pack_qs4_weights_and_biases( const void* params) { const size_t extra_bytes = extra_data0_element_size + extra_data1_element_size; const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( - gemm_config, input_channels, k_stride, extra_bytes); + gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, groups, weights_stride, @@ -1469,9 +1470,111 @@ void xnn_pack_qs4_weights_and_biases( extra_data1_element_size, packed_weights_ptr, extra_bytes, params); } +size_t xnn_packed_stride_qb4_weights_and_biases( + const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, + size_t 
k_stride, size_t extra_bytes) { + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const uint32_t nr = gemm_config->nr; + const size_t planes = gemm_config->planes; + + size_t input_channels = round_up_po2(k, planes); + + size_t block_scale_bytes = 0; + size_t num_blocks = 0; + const bool block_wise = (block_size != 0); + if (block_wise) { + num_blocks = input_channels / block_size; + block_scale_bytes += num_blocks * sizeof(uint16_t); + } + + const size_t bias_element_size = sizeof(int32_t); + const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; + return (k_stride << log2_filter_element_size) + bias_element_size + + extra_bytes + block_scale_bytes; +} + +void xnn_pack_qb4_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params) { + + const uint32_t nr = gemm_config->nr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const size_t planes = gemm_config->planes; + + const size_t extra_bytes_bl = sizeof(uint16_t); + const size_t extra_bytes_n = sizeof(uint32_t); + if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { + xnn_pack_qs8_qb4w_gemm_gio_w( + /*g=*/groups, + /*nc=*/output_channels, + /*kc=*/input_channels, + /*nr=*/nr, + /*kr=*/kr, + /*sr=*/sr, + /*k_stride=*/k_stride, + /*bl=*/block_size, + /*k=*/(const uint8_t*)weights, + /*bias=*/NULL, + /*scale=*/(const xnn_bfloat16*)extra_data1, + /*packed_weights=*/packed_weights_ptr, + /*extra_bytes_bl=*/nr * extra_bytes_bl, + /*extra_bytes_n=*/nr * extra_bytes_n, + /*params*/(const struct xnn_qs8_qc4w_packing_params *)params); + } else { + xnn_pack_qs8_qb4w_gemm_goi_w( + /*g=*/groups, + /*nc=*/output_channels, + /*kc=*/input_channels, + /*nr=*/nr, + /*kr=*/kr, + /*sr=*/sr, + /*bl=*/block_size, + /*k=*/(const uint8_t*)weights, + /*bias=*/NULL, + /*scale=*/(const xnn_bfloat16*)extra_data1, + /*packed_weights=*/packed_weights_ptr, + /*extra_bytes_bl=*/nr * extra_bytes_bl, + /*extra_bytes_n=*/nr * extra_bytes_n, + /*params*/(const struct xnn_qs8_qc4w_packing_params *)params); + } + + // fill in kernel scales + const size_t num_blocks = input_channels / block_size; + const size_t weights_stride = xnn_packed_stride_qb4_weights_and_biases(gemm_config, input_channels, block_size, k_stride, extra_bytes_n); + void* weights_start = (void*) ((uintptr_t) packed_weights_ptr + + nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); + + const size_t block_stride = /*weights*/block_size / 2 + sizeof(uint16_t); + xnn_init_blockwise_scale_bf16_params( + output_channels, nr, nr, + nr * weights_stride, + nr * weights_stride, + /*num_blocks=*/num_blocks, + /*block_stride=*/gemm_config->nr * block_stride, + 0, + (const xnn_bfloat16*)extra_data1, weights_start); + + // fill in bias if not null + if (accumulator_init != nullptr) { + weights_start = (void*) ((uintptr_t) packed_weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; + xnn_init_qs8_qc8w_scale_fp32_params( + output_channels, gemm_config->nr, gemm_config->nr, + gemm_config->nr * weights_stride, gemm_config->nr * 
weights_stride, 0, + (const float*)accumulator_init, weights_start); + } +} + size_t xnn_packed_stride_qu8_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size, + size_t k_stride, size_t extra_bytes) { const size_t bias_element_size = sizeof(int32_t); const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; return (k_stride << log2_filter_element_size) + bias_element_size + @@ -1481,7 +1584,7 @@ size_t xnn_packed_stride_qu8_weights_and_biases( void xnn_pack_qu8_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1490,7 +1593,7 @@ void xnn_pack_qu8_weights_and_biases( const size_t extra_bytes = extra_data0_element_size + extra_data1_element_size; const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( - gemm_config, input_channels, k_stride, extra_bytes); + gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, groups, weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qu8_gemm_gio_w, @@ -1502,7 +1605,7 @@ void xnn_pack_qu8_weights_and_biases( #if XNN_ENABLE_KLEIDIAI size_t xnn_packed_stride_kai_qs4_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) { const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; @@ -1513,7 +1616,7 @@ size_t xnn_packed_stride_kai_qs4_weights_and_biases( void xnn_pack_kai_qs4_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1557,7 +1660,7 @@ void xnn_pack_kai_qs4_weights_and_biases( } size_t xnn_packed_stride_kai_f32_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + const struct xnn_gemm_config* gemm_config, size_t unused_block_size, size_t k, size_t unused_k_stride, size_t extra_bytes) { size_t ret_val = kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(k) / @@ -1576,7 +1679,7 @@ void transpose_weights(const float* in, float* out, size_t height, size_t width) void xnn_pack_kai_f32_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const 
void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1625,7 +1728,7 @@ void xnn_pack_kai_f32_weights_and_biases( size_t xnn_packed_stride_kai_qb4_weights_and_biases( const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, - size_t extra_bytes) { + size_t unused_k_stride, size_t extra_bytes) { const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; const uint32_t nr = gemm_config->nr; @@ -1643,7 +1746,7 @@ size_t xnn_packed_stride_kai_qb4_weights_and_biases( void xnn_pack_kai_qb4_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t block_size, const void* accumulator_init, const void* weights, + size_t block_size, size_t unused_k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1693,6 +1796,19 @@ void xnn_pack_kai_qb4_weights_and_biases( /*extra_bytes=*/0, &kai_params); } + + // init bias + const size_t weights_stride = xnn_packed_stride_kai_qb4_weights_and_biases( + gemm_config, input_channels, block_size, unused_k_stride, 0); + if (accumulator_init != NULL) { + void* weights_start = (void*) ((uintptr_t) packed_weights_ptr + + nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); + weights_start = (void*) ((uintptr_t) packed_weights_ptr + nr * (weights_stride - sizeof(float))) ; + xnn_init_qs8_qc8w_scale_fp32_params( + output_channels, nr, nr, + nr * weights_stride, nr * weights_stride, 0, + (const float*)accumulator_init, weights_start); + } } #endif // XNN_ENABLE_KLEIDIAI diff --git a/src/xnnpack/config-types.h b/src/xnnpack/config-types.h index 7262d376ef3..bcd1bfb2f80 100644 --- a/src/xnnpack/config-types.h +++ b/src/xnnpack/config-types.h @@ -193,8 +193,6 @@ struct xnn_gemm_config { xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio; // Deprecated. Use pack_weights_and_biases instead. xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi; - // TODO(b/346765736): Use pack_weights_and_biases instead. - xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi_bl; xnn_pack_conv_goki_w_fn pack_igemm_goki; xnn_pack_conv_kgo_w_fn pack_igemm_kgo; xnn_pack_deconv_goki_w_fn pack_deconv_goki; diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index edd1a8aae03..e6df9020b14 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -2229,6 +2229,7 @@ typedef void (*xnn_pack_weights_and_biases_fn)( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // // We tile packing by output channels, in GIO layout, the k (row) index // needs to be able to skip by the actual number of output channels, and not // just the argument nc. E.g. 
if weights is 1x3x5, and nr is 2, we tile the @@ -2255,6 +2256,7 @@ typedef void (*xnn_pack_weights_and_biases_fn)( typedef size_t (*xnn_packed_stride_weights_and_biases_fn)( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index 8850d370249..c58617bec60 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -387,6 +387,7 @@ XNN_INTERNAL void xnn_pack_qs8_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -402,6 +403,7 @@ XNN_INTERNAL void xnn_pack_qs8_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qs8_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); @@ -412,6 +414,7 @@ XNN_INTERNAL void xnn_pack_qs4_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -427,6 +430,33 @@ XNN_INTERNAL void xnn_pack_qs4_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qs4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // + size_t k_stride, // + size_t extra_bytes); + +XNN_INTERNAL void xnn_pack_qb4_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + +XNN_INTERNAL size_t xnn_packed_stride_qb4_weights_and_biases( + const struct xnn_gemm_config* gemm_config, // + size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); @@ -436,6 +466,7 @@ XNN_INTERNAL void xnn_pack_qu8_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -451,6 +482,7 @@ XNN_INTERNAL void xnn_pack_qu8_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qu8_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); @@ -461,6 +493,7 @@ XNN_INTERNAL void xnn_pack_kai_qs4_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -476,17 +509,18 @@ XNN_INTERNAL void xnn_pack_kai_qs4_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); size_t xnn_packed_stride_kai_f32_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, - size_t extra_bytes); + const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, + size_t unused_k_stride, size_t extra_bytes); void 
xnn_pack_kai_f32_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -500,6 +534,7 @@ XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases( size_t output_channels, // size_t groups, // size_t block_size, // + size_t k_stride, // const void* accumulator_init, // const void* weights, // xnn_init_scale_params_fn init_extra_data0_fn, // @@ -515,6 +550,7 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qb4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // size_t block_size, // + size_t k_stride, // size_t extra_bytes); #endif // XNN_ENABLE_KLEIDIAI From 8f768614382e107e404bb5635dda7d7e0ee52e0f Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 5 Dec 2024 15:29:04 -0500 Subject: [PATCH 2/5] add back pack_gemm_goi_bl ukernel --- src/xnnpack/config-types.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/xnnpack/config-types.h b/src/xnnpack/config-types.h index bcd1bfb2f80..7262d376ef3 100644 --- a/src/xnnpack/config-types.h +++ b/src/xnnpack/config-types.h @@ -193,6 +193,8 @@ struct xnn_gemm_config { xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio; // Deprecated. Use pack_weights_and_biases instead. xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi; + // TODO(b/346765736): Use pack_weights_and_biases instead. + xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi_bl; xnn_pack_conv_goki_w_fn pack_igemm_goki; xnn_pack_conv_kgo_w_fn pack_igemm_kgo; xnn_pack_deconv_goki_w_fn pack_deconv_goki; From 35b59d9b09def6d422c017c26d75951d8c183b8f Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 15 Nov 2024 16:12:12 -0800 Subject: [PATCH 3/5] [Fast Packing] Add x16c8 and x16c4 packing ukernels for qb4 --- CMakeLists.txt | 1 + cmake/gen/scalar_microkernels.cmake | 6 +- gen/scalar_microkernels.bzl | 6 +- scripts/build-local.sh | 2 +- scripts/generate-qb4-packw.sh | 5 + .../gen/qb4-packw-x16c4-gemm-goi-scalar.c | 1238 +++++++++++ .../gen/qb4-packw-x16c8-gemm-goi-scalar.c | 1858 +++++++++++++++++ src/qb4-packw/kr-scalar.c.in | 200 ++ src/qb4-packw/qb4-packw.h | 9 + src/reference/packing.cc | 4 + src/xnnpack/microfnptr.h | 16 + src/xnnpack/microparams.h | 1 - src/xnnpack/packw.h | 22 + test/packw-microkernel-tester.h | 138 ++ test/qb4-packw.cc | 148 ++ 15 files changed, 3648 insertions(+), 6 deletions(-) create mode 100644 scripts/generate-qb4-packw.sh create mode 100644 src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c create mode 100644 src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c create mode 100644 src/qb4-packw/kr-scalar.c.in create mode 100644 src/qb4-packw/qb4-packw.h create mode 100644 test/qb4-packw.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 856a47d1a66..1b1688ed70b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1469,6 +1469,7 @@ IF(XNNPACK_BUILD_TESTS) x8-packw qs8-packw qs8-qc4w-packw + qb4-packw x8-zip xN-transpose xx-fill diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index a07d6791fc0..be694e9477e 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -157,8 +157,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u4.c 
src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c - src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c - src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-lrintf.c @@ -543,6 +541,8 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c + src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c + src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c @@ -621,6 +621,8 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x32c8-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index 2243ba6160b..fe920fe5c76 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -153,8 +153,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u4.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c", - "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c", - "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-lrintf.c", @@ -540,6 +538,8 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c", "src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c", "src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c", + "src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c", + "src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c", @@ -618,6 +618,8 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c", "src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c", "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x32c8-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c", diff --git a/scripts/build-local.sh b/scripts/build-local.sh index 8fd10736b36..1fb96e80796 100755 --- a/scripts/build-local.sh +++ b/scripts/build-local.sh @@ -13,7 +13,7 @@ mkdir -p 
build/local CMAKE_ARGS=() # CMake-level configuration -CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Debug") CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") # If Ninja is installed, prefer it to Make diff --git a/scripts/generate-qb4-packw.sh b/scripts/generate-qb4-packw.sh new file mode 100644 index 00000000000..2994b040039 --- /dev/null +++ b/scripts/generate-qb4-packw.sh @@ -0,0 +1,5 @@ +# C8 Packing +tools/xngen src/qb4-packw/kr-scalar.c.in -D NR=16 -D KR=8 -o src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c + +# C4 Packing +tools/xngen src/qb4-packw/kr-scalar.c.in -D NR=16 -D KR=4 -o src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c diff --git a/src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c b/src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c new file mode 100644 index 00000000000..bd5826c37fb --- /dev/null +++ b/src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c @@ -0,0 +1,1238 @@ +// Auto-generated file. Do not edit! +// Template: src/qb4-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include "xnnpack/packw.h" + +void xnn_qb4_packw_gemm_goi_ukernel_x16c4__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + const uint8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes_bl, + size_t extra_bytes_n, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 4); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(extra_bytes_bl == nr * sizeof(uint16_t)); + assert(extra_bytes_n == nr * sizeof(float)); + assert(params != NULL); + assert(kc % bl == 0); + size_t num_blocks = kc / bl; + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = (uint32_t) (((const struct xnn_qs8_qc4w_packing_params*) params)->input_zero_point + 0); + + do { + // NC main loop multiple of 16 + const uint8_t* w0 = (const uint8_t*) weights; + const uint16_t* s0 = (const uint16_t*) scale; + size_t n = nc; + for (;n >= 16; n -= 16) { + float* packed_k_scaled_sum = (float*) out; + ((float*) out)[0] = 0; + ((float*) out)[1] = 0; + ((float*) out)[2] = 0; + ((float*) out)[3] = 0; + ((float*) out)[4] = 0; + ((float*) out)[5] = 0; + ((float*) out)[6] = 0; + ((float*) out)[7] = 0; + ((float*) out)[8] = 0; + ((float*) out)[9] = 0; + ((float*) out)[10] = 0; + ((float*) out)[11] = 0; + ((float*) out)[12] = 0; + ((float*) out)[13] = 0; + ((float*) out)[14] = 0; + ((float*) out)[15] = 0; + out += 16 * sizeof(float); + + // KC/2 bytes is KC Nibbles + const uint8_t* w1 = w0 + (kc >> 1); + const uint8_t* w2 = w1 + (kc >> 1); + const uint8_t* w3 = w2 + (kc >> 1); + const uint8_t* w4 = w3 + (kc >> 1); + const uint8_t* w5 = w4 + (kc >> 1); + const uint8_t* w6 = w5 + (kc >> 1); + const uint8_t* w7 = w6 + (kc >> 1); + const uint8_t* w8 = w7 + (kc >> 1); + const uint8_t* w9 = w8 + (kc >> 1); + const uint8_t* w10 = w9 + (kc >> 1); + const uint8_t* w11 = w10 + (kc >> 1); + const uint8_t* w12 = w11 + (kc >> 1); + const uint8_t* w13 = w12 + (kc >> 1); + const uint8_t* w14 = w13 + (kc >> 1); + const uint8_t* w15 = w14 + (kc >> 1); + + // scales + const uint16_t* s1 = s0 + num_blocks; + const uint16_t* s2 = s1 + num_blocks; + const 
uint16_t* s3 = s2 + num_blocks; + const uint16_t* s4 = s3 + num_blocks; + const uint16_t* s5 = s4 + num_blocks; + const uint16_t* s6 = s5 + num_blocks; + const uint16_t* s7 = s6 + num_blocks; + const uint16_t* s8 = s7 + num_blocks; + const uint16_t* s9 = s8 + num_blocks; + const uint16_t* s10 = s9 + num_blocks; + const uint16_t* s11 = s10 + num_blocks; + const uint16_t* s12 = s11 + num_blocks; + const uint16_t* s13 = s12 + num_blocks; + const uint16_t* s14 = s13 + num_blocks; + const uint16_t* s15 = s14 + num_blocks; + + + size_t kb = kc; + // Process k by blocks (bl) + for (; kb >= bl; kb-=bl) { + // Initialize KSum as subtracting bl zero points (8) + int32_t ksum0 = 0; + int32_t ksum1 = 0; + int32_t ksum2 = 0; + int32_t ksum3 = 0; + int32_t ksum4 = 0; + int32_t ksum5 = 0; + int32_t ksum6 = 0; + int32_t ksum7 = 0; + int32_t ksum8 = 0; + int32_t ksum9 = 0; + int32_t ksum10 = 0; + int32_t ksum11 = 0; + int32_t ksum12 = 0; + int32_t ksum13 = 0; + int32_t ksum14 = 0; + int32_t ksum15 = 0; + size_t k = bl; + for(; k >= 8; k-=8) { + const uint8_t v0x0 = w0[0] & 0xF; + const uint8_t v0x1 = w0[0] >> 4; + const uint8_t v0x2 = w0[1] & 0xF; + const uint8_t v0x3 = w0[1] >> 4; + const uint8_t v0x4 = w0[2] & 0xF; + const uint8_t v0x5 = w0[2] >> 4; + const uint8_t v0x6 = w0[3] & 0xF; + const uint8_t v0x7 = w0[3] >> 4; + w0 += 4; + + ksum0 += (uint32_t) (v0x0); + ksum0 += (uint32_t) (v0x1); + ksum0 += (uint32_t) (v0x2); + ksum0 += (uint32_t) (v0x3); + ksum0 += (uint32_t) (v0x4); + ksum0 += (uint32_t) (v0x5); + ksum0 += (uint32_t) (v0x6); + ksum0 += (uint32_t) (v0x7); + // Subtract 8 zero points (8) + ksum0 -= 64; + + out[0] = (v0x0 | (v0x4 << 4)) ^ 0x88; + out[1] = (v0x1 | (v0x5 << 4)) ^ 0x88; + out[2] = (v0x2 | (v0x6 << 4)) ^ 0x88; + out[3] = (v0x3 | (v0x7 << 4)) ^ 0x88; + const uint8_t v1x0 = w1[0] & 0xF; + const uint8_t v1x1 = w1[0] >> 4; + const uint8_t v1x2 = w1[1] & 0xF; + const uint8_t v1x3 = w1[1] >> 4; + const uint8_t v1x4 = w1[2] & 0xF; + const uint8_t v1x5 = w1[2] >> 4; + const uint8_t v1x6 = w1[3] & 0xF; + const uint8_t v1x7 = w1[3] >> 4; + w1 += 4; + + ksum1 += (uint32_t) (v1x0); + ksum1 += (uint32_t) (v1x1); + ksum1 += (uint32_t) (v1x2); + ksum1 += (uint32_t) (v1x3); + ksum1 += (uint32_t) (v1x4); + ksum1 += (uint32_t) (v1x5); + ksum1 += (uint32_t) (v1x6); + ksum1 += (uint32_t) (v1x7); + // Subtract 8 zero points (8) + ksum1 -= 64; + + out[4] = (v1x0 | (v1x4 << 4)) ^ 0x88; + out[5] = (v1x1 | (v1x5 << 4)) ^ 0x88; + out[6] = (v1x2 | (v1x6 << 4)) ^ 0x88; + out[7] = (v1x3 | (v1x7 << 4)) ^ 0x88; + const uint8_t v2x0 = w2[0] & 0xF; + const uint8_t v2x1 = w2[0] >> 4; + const uint8_t v2x2 = w2[1] & 0xF; + const uint8_t v2x3 = w2[1] >> 4; + const uint8_t v2x4 = w2[2] & 0xF; + const uint8_t v2x5 = w2[2] >> 4; + const uint8_t v2x6 = w2[3] & 0xF; + const uint8_t v2x7 = w2[3] >> 4; + w2 += 4; + + ksum2 += (uint32_t) (v2x0); + ksum2 += (uint32_t) (v2x1); + ksum2 += (uint32_t) (v2x2); + ksum2 += (uint32_t) (v2x3); + ksum2 += (uint32_t) (v2x4); + ksum2 += (uint32_t) (v2x5); + ksum2 += (uint32_t) (v2x6); + ksum2 += (uint32_t) (v2x7); + // Subtract 8 zero points (8) + ksum2 -= 64; + + out[8] = (v2x0 | (v2x4 << 4)) ^ 0x88; + out[9] = (v2x1 | (v2x5 << 4)) ^ 0x88; + out[10] = (v2x2 | (v2x6 << 4)) ^ 0x88; + out[11] = (v2x3 | (v2x7 << 4)) ^ 0x88; + const uint8_t v3x0 = w3[0] & 0xF; + const uint8_t v3x1 = w3[0] >> 4; + const uint8_t v3x2 = w3[1] & 0xF; + const uint8_t v3x3 = w3[1] >> 4; + const uint8_t v3x4 = w3[2] & 0xF; + const uint8_t v3x5 = w3[2] >> 4; + const uint8_t v3x6 = w3[3] & 0xF; + const uint8_t v3x7 = 
w3[3] >> 4; + w3 += 4; + + ksum3 += (uint32_t) (v3x0); + ksum3 += (uint32_t) (v3x1); + ksum3 += (uint32_t) (v3x2); + ksum3 += (uint32_t) (v3x3); + ksum3 += (uint32_t) (v3x4); + ksum3 += (uint32_t) (v3x5); + ksum3 += (uint32_t) (v3x6); + ksum3 += (uint32_t) (v3x7); + // Subtract 8 zero points (8) + ksum3 -= 64; + + out[12] = (v3x0 | (v3x4 << 4)) ^ 0x88; + out[13] = (v3x1 | (v3x5 << 4)) ^ 0x88; + out[14] = (v3x2 | (v3x6 << 4)) ^ 0x88; + out[15] = (v3x3 | (v3x7 << 4)) ^ 0x88; + const uint8_t v4x0 = w4[0] & 0xF; + const uint8_t v4x1 = w4[0] >> 4; + const uint8_t v4x2 = w4[1] & 0xF; + const uint8_t v4x3 = w4[1] >> 4; + const uint8_t v4x4 = w4[2] & 0xF; + const uint8_t v4x5 = w4[2] >> 4; + const uint8_t v4x6 = w4[3] & 0xF; + const uint8_t v4x7 = w4[3] >> 4; + w4 += 4; + + ksum4 += (uint32_t) (v4x0); + ksum4 += (uint32_t) (v4x1); + ksum4 += (uint32_t) (v4x2); + ksum4 += (uint32_t) (v4x3); + ksum4 += (uint32_t) (v4x4); + ksum4 += (uint32_t) (v4x5); + ksum4 += (uint32_t) (v4x6); + ksum4 += (uint32_t) (v4x7); + // Subtract 8 zero points (8) + ksum4 -= 64; + + out[16] = (v4x0 | (v4x4 << 4)) ^ 0x88; + out[17] = (v4x1 | (v4x5 << 4)) ^ 0x88; + out[18] = (v4x2 | (v4x6 << 4)) ^ 0x88; + out[19] = (v4x3 | (v4x7 << 4)) ^ 0x88; + const uint8_t v5x0 = w5[0] & 0xF; + const uint8_t v5x1 = w5[0] >> 4; + const uint8_t v5x2 = w5[1] & 0xF; + const uint8_t v5x3 = w5[1] >> 4; + const uint8_t v5x4 = w5[2] & 0xF; + const uint8_t v5x5 = w5[2] >> 4; + const uint8_t v5x6 = w5[3] & 0xF; + const uint8_t v5x7 = w5[3] >> 4; + w5 += 4; + + ksum5 += (uint32_t) (v5x0); + ksum5 += (uint32_t) (v5x1); + ksum5 += (uint32_t) (v5x2); + ksum5 += (uint32_t) (v5x3); + ksum5 += (uint32_t) (v5x4); + ksum5 += (uint32_t) (v5x5); + ksum5 += (uint32_t) (v5x6); + ksum5 += (uint32_t) (v5x7); + // Subtract 8 zero points (8) + ksum5 -= 64; + + out[20] = (v5x0 | (v5x4 << 4)) ^ 0x88; + out[21] = (v5x1 | (v5x5 << 4)) ^ 0x88; + out[22] = (v5x2 | (v5x6 << 4)) ^ 0x88; + out[23] = (v5x3 | (v5x7 << 4)) ^ 0x88; + const uint8_t v6x0 = w6[0] & 0xF; + const uint8_t v6x1 = w6[0] >> 4; + const uint8_t v6x2 = w6[1] & 0xF; + const uint8_t v6x3 = w6[1] >> 4; + const uint8_t v6x4 = w6[2] & 0xF; + const uint8_t v6x5 = w6[2] >> 4; + const uint8_t v6x6 = w6[3] & 0xF; + const uint8_t v6x7 = w6[3] >> 4; + w6 += 4; + + ksum6 += (uint32_t) (v6x0); + ksum6 += (uint32_t) (v6x1); + ksum6 += (uint32_t) (v6x2); + ksum6 += (uint32_t) (v6x3); + ksum6 += (uint32_t) (v6x4); + ksum6 += (uint32_t) (v6x5); + ksum6 += (uint32_t) (v6x6); + ksum6 += (uint32_t) (v6x7); + // Subtract 8 zero points (8) + ksum6 -= 64; + + out[24] = (v6x0 | (v6x4 << 4)) ^ 0x88; + out[25] = (v6x1 | (v6x5 << 4)) ^ 0x88; + out[26] = (v6x2 | (v6x6 << 4)) ^ 0x88; + out[27] = (v6x3 | (v6x7 << 4)) ^ 0x88; + const uint8_t v7x0 = w7[0] & 0xF; + const uint8_t v7x1 = w7[0] >> 4; + const uint8_t v7x2 = w7[1] & 0xF; + const uint8_t v7x3 = w7[1] >> 4; + const uint8_t v7x4 = w7[2] & 0xF; + const uint8_t v7x5 = w7[2] >> 4; + const uint8_t v7x6 = w7[3] & 0xF; + const uint8_t v7x7 = w7[3] >> 4; + w7 += 4; + + ksum7 += (uint32_t) (v7x0); + ksum7 += (uint32_t) (v7x1); + ksum7 += (uint32_t) (v7x2); + ksum7 += (uint32_t) (v7x3); + ksum7 += (uint32_t) (v7x4); + ksum7 += (uint32_t) (v7x5); + ksum7 += (uint32_t) (v7x6); + ksum7 += (uint32_t) (v7x7); + // Subtract 8 zero points (8) + ksum7 -= 64; + + out[28] = (v7x0 | (v7x4 << 4)) ^ 0x88; + out[29] = (v7x1 | (v7x5 << 4)) ^ 0x88; + out[30] = (v7x2 | (v7x6 << 4)) ^ 0x88; + out[31] = (v7x3 | (v7x7 << 4)) ^ 0x88; + const uint8_t v8x0 = w8[0] & 0xF; + const uint8_t v8x1 = w8[0] >> 4; + 
const uint8_t v8x2 = w8[1] & 0xF; + const uint8_t v8x3 = w8[1] >> 4; + const uint8_t v8x4 = w8[2] & 0xF; + const uint8_t v8x5 = w8[2] >> 4; + const uint8_t v8x6 = w8[3] & 0xF; + const uint8_t v8x7 = w8[3] >> 4; + w8 += 4; + + ksum8 += (uint32_t) (v8x0); + ksum8 += (uint32_t) (v8x1); + ksum8 += (uint32_t) (v8x2); + ksum8 += (uint32_t) (v8x3); + ksum8 += (uint32_t) (v8x4); + ksum8 += (uint32_t) (v8x5); + ksum8 += (uint32_t) (v8x6); + ksum8 += (uint32_t) (v8x7); + // Subtract 8 zero points (8) + ksum8 -= 64; + + out[32] = (v8x0 | (v8x4 << 4)) ^ 0x88; + out[33] = (v8x1 | (v8x5 << 4)) ^ 0x88; + out[34] = (v8x2 | (v8x6 << 4)) ^ 0x88; + out[35] = (v8x3 | (v8x7 << 4)) ^ 0x88; + const uint8_t v9x0 = w9[0] & 0xF; + const uint8_t v9x1 = w9[0] >> 4; + const uint8_t v9x2 = w9[1] & 0xF; + const uint8_t v9x3 = w9[1] >> 4; + const uint8_t v9x4 = w9[2] & 0xF; + const uint8_t v9x5 = w9[2] >> 4; + const uint8_t v9x6 = w9[3] & 0xF; + const uint8_t v9x7 = w9[3] >> 4; + w9 += 4; + + ksum9 += (uint32_t) (v9x0); + ksum9 += (uint32_t) (v9x1); + ksum9 += (uint32_t) (v9x2); + ksum9 += (uint32_t) (v9x3); + ksum9 += (uint32_t) (v9x4); + ksum9 += (uint32_t) (v9x5); + ksum9 += (uint32_t) (v9x6); + ksum9 += (uint32_t) (v9x7); + // Subtract 8 zero points (8) + ksum9 -= 64; + + out[36] = (v9x0 | (v9x4 << 4)) ^ 0x88; + out[37] = (v9x1 | (v9x5 << 4)) ^ 0x88; + out[38] = (v9x2 | (v9x6 << 4)) ^ 0x88; + out[39] = (v9x3 | (v9x7 << 4)) ^ 0x88; + const uint8_t v10x0 = w10[0] & 0xF; + const uint8_t v10x1 = w10[0] >> 4; + const uint8_t v10x2 = w10[1] & 0xF; + const uint8_t v10x3 = w10[1] >> 4; + const uint8_t v10x4 = w10[2] & 0xF; + const uint8_t v10x5 = w10[2] >> 4; + const uint8_t v10x6 = w10[3] & 0xF; + const uint8_t v10x7 = w10[3] >> 4; + w10 += 4; + + ksum10 += (uint32_t) (v10x0); + ksum10 += (uint32_t) (v10x1); + ksum10 += (uint32_t) (v10x2); + ksum10 += (uint32_t) (v10x3); + ksum10 += (uint32_t) (v10x4); + ksum10 += (uint32_t) (v10x5); + ksum10 += (uint32_t) (v10x6); + ksum10 += (uint32_t) (v10x7); + // Subtract 8 zero points (8) + ksum10 -= 64; + + out[40] = (v10x0 | (v10x4 << 4)) ^ 0x88; + out[41] = (v10x1 | (v10x5 << 4)) ^ 0x88; + out[42] = (v10x2 | (v10x6 << 4)) ^ 0x88; + out[43] = (v10x3 | (v10x7 << 4)) ^ 0x88; + const uint8_t v11x0 = w11[0] & 0xF; + const uint8_t v11x1 = w11[0] >> 4; + const uint8_t v11x2 = w11[1] & 0xF; + const uint8_t v11x3 = w11[1] >> 4; + const uint8_t v11x4 = w11[2] & 0xF; + const uint8_t v11x5 = w11[2] >> 4; + const uint8_t v11x6 = w11[3] & 0xF; + const uint8_t v11x7 = w11[3] >> 4; + w11 += 4; + + ksum11 += (uint32_t) (v11x0); + ksum11 += (uint32_t) (v11x1); + ksum11 += (uint32_t) (v11x2); + ksum11 += (uint32_t) (v11x3); + ksum11 += (uint32_t) (v11x4); + ksum11 += (uint32_t) (v11x5); + ksum11 += (uint32_t) (v11x6); + ksum11 += (uint32_t) (v11x7); + // Subtract 8 zero points (8) + ksum11 -= 64; + + out[44] = (v11x0 | (v11x4 << 4)) ^ 0x88; + out[45] = (v11x1 | (v11x5 << 4)) ^ 0x88; + out[46] = (v11x2 | (v11x6 << 4)) ^ 0x88; + out[47] = (v11x3 | (v11x7 << 4)) ^ 0x88; + const uint8_t v12x0 = w12[0] & 0xF; + const uint8_t v12x1 = w12[0] >> 4; + const uint8_t v12x2 = w12[1] & 0xF; + const uint8_t v12x3 = w12[1] >> 4; + const uint8_t v12x4 = w12[2] & 0xF; + const uint8_t v12x5 = w12[2] >> 4; + const uint8_t v12x6 = w12[3] & 0xF; + const uint8_t v12x7 = w12[3] >> 4; + w12 += 4; + + ksum12 += (uint32_t) (v12x0); + ksum12 += (uint32_t) (v12x1); + ksum12 += (uint32_t) (v12x2); + ksum12 += (uint32_t) (v12x3); + ksum12 += (uint32_t) (v12x4); + ksum12 += (uint32_t) (v12x5); + ksum12 += (uint32_t) (v12x6); + 
ksum12 += (uint32_t) (v12x7); + // Subtract 8 zero points (8) + ksum12 -= 64; + + out[48] = (v12x0 | (v12x4 << 4)) ^ 0x88; + out[49] = (v12x1 | (v12x5 << 4)) ^ 0x88; + out[50] = (v12x2 | (v12x6 << 4)) ^ 0x88; + out[51] = (v12x3 | (v12x7 << 4)) ^ 0x88; + const uint8_t v13x0 = w13[0] & 0xF; + const uint8_t v13x1 = w13[0] >> 4; + const uint8_t v13x2 = w13[1] & 0xF; + const uint8_t v13x3 = w13[1] >> 4; + const uint8_t v13x4 = w13[2] & 0xF; + const uint8_t v13x5 = w13[2] >> 4; + const uint8_t v13x6 = w13[3] & 0xF; + const uint8_t v13x7 = w13[3] >> 4; + w13 += 4; + + ksum13 += (uint32_t) (v13x0); + ksum13 += (uint32_t) (v13x1); + ksum13 += (uint32_t) (v13x2); + ksum13 += (uint32_t) (v13x3); + ksum13 += (uint32_t) (v13x4); + ksum13 += (uint32_t) (v13x5); + ksum13 += (uint32_t) (v13x6); + ksum13 += (uint32_t) (v13x7); + // Subtract 8 zero points (8) + ksum13 -= 64; + + out[52] = (v13x0 | (v13x4 << 4)) ^ 0x88; + out[53] = (v13x1 | (v13x5 << 4)) ^ 0x88; + out[54] = (v13x2 | (v13x6 << 4)) ^ 0x88; + out[55] = (v13x3 | (v13x7 << 4)) ^ 0x88; + const uint8_t v14x0 = w14[0] & 0xF; + const uint8_t v14x1 = w14[0] >> 4; + const uint8_t v14x2 = w14[1] & 0xF; + const uint8_t v14x3 = w14[1] >> 4; + const uint8_t v14x4 = w14[2] & 0xF; + const uint8_t v14x5 = w14[2] >> 4; + const uint8_t v14x6 = w14[3] & 0xF; + const uint8_t v14x7 = w14[3] >> 4; + w14 += 4; + + ksum14 += (uint32_t) (v14x0); + ksum14 += (uint32_t) (v14x1); + ksum14 += (uint32_t) (v14x2); + ksum14 += (uint32_t) (v14x3); + ksum14 += (uint32_t) (v14x4); + ksum14 += (uint32_t) (v14x5); + ksum14 += (uint32_t) (v14x6); + ksum14 += (uint32_t) (v14x7); + // Subtract 8 zero points (8) + ksum14 -= 64; + + out[56] = (v14x0 | (v14x4 << 4)) ^ 0x88; + out[57] = (v14x1 | (v14x5 << 4)) ^ 0x88; + out[58] = (v14x2 | (v14x6 << 4)) ^ 0x88; + out[59] = (v14x3 | (v14x7 << 4)) ^ 0x88; + const uint8_t v15x0 = w15[0] & 0xF; + const uint8_t v15x1 = w15[0] >> 4; + const uint8_t v15x2 = w15[1] & 0xF; + const uint8_t v15x3 = w15[1] >> 4; + const uint8_t v15x4 = w15[2] & 0xF; + const uint8_t v15x5 = w15[2] >> 4; + const uint8_t v15x6 = w15[3] & 0xF; + const uint8_t v15x7 = w15[3] >> 4; + w15 += 4; + + ksum15 += (uint32_t) (v15x0); + ksum15 += (uint32_t) (v15x1); + ksum15 += (uint32_t) (v15x2); + ksum15 += (uint32_t) (v15x3); + ksum15 += (uint32_t) (v15x4); + ksum15 += (uint32_t) (v15x5); + ksum15 += (uint32_t) (v15x6); + ksum15 += (uint32_t) (v15x7); + // Subtract 8 zero points (8) + ksum15 -= 64; + + out[60] = (v15x0 | (v15x4 << 4)) ^ 0x88; + out[61] = (v15x1 | (v15x5 << 4)) ^ 0x88; + out[62] = (v15x2 | (v15x6 << 4)) ^ 0x88; + out[63] = (v15x3 | (v15x7 << 4)) ^ 0x88; + + out += 64; + } + float scale0 = math_cvt_fp32_bf16(s0[0]); + float scale1 = math_cvt_fp32_bf16(s1[0]); + float scale2 = math_cvt_fp32_bf16(s2[0]); + float scale3 = math_cvt_fp32_bf16(s3[0]); + float scale4 = math_cvt_fp32_bf16(s4[0]); + float scale5 = math_cvt_fp32_bf16(s5[0]); + float scale6 = math_cvt_fp32_bf16(s6[0]); + float scale7 = math_cvt_fp32_bf16(s7[0]); + float scale8 = math_cvt_fp32_bf16(s8[0]); + float scale9 = math_cvt_fp32_bf16(s9[0]); + float scale10 = math_cvt_fp32_bf16(s10[0]); + float scale11 = math_cvt_fp32_bf16(s11[0]); + float scale12 = math_cvt_fp32_bf16(s12[0]); + float scale13 = math_cvt_fp32_bf16(s13[0]); + float scale14 = math_cvt_fp32_bf16(s14[0]); + float scale15 = math_cvt_fp32_bf16(s15[0]); + s0 += 1; + s1 += 1; + s2 += 1; + s3 += 1; + s4 += 1; + s5 += 1; + s6 += 1; + s7 += 1; + s8 += 1; + s9 += 1; + s10 += 1; + s11 += 1; + s12 += 1; + s13 += 1; + s14 += 1; + s15 += 1; + + + 
packed_k_scaled_sum[0] -= (float)ksum0 * izp * scale0; + packed_k_scaled_sum[1] -= (float)ksum1 * izp * scale1; + packed_k_scaled_sum[2] -= (float)ksum2 * izp * scale2; + packed_k_scaled_sum[3] -= (float)ksum3 * izp * scale3; + packed_k_scaled_sum[4] -= (float)ksum4 * izp * scale4; + packed_k_scaled_sum[5] -= (float)ksum5 * izp * scale5; + packed_k_scaled_sum[6] -= (float)ksum6 * izp * scale6; + packed_k_scaled_sum[7] -= (float)ksum7 * izp * scale7; + packed_k_scaled_sum[8] -= (float)ksum8 * izp * scale8; + packed_k_scaled_sum[9] -= (float)ksum9 * izp * scale9; + packed_k_scaled_sum[10] -= (float)ksum10 * izp * scale10; + packed_k_scaled_sum[11] -= (float)ksum11 * izp * scale11; + packed_k_scaled_sum[12] -= (float)ksum12 * izp * scale12; + packed_k_scaled_sum[13] -= (float)ksum13 * izp * scale13; + packed_k_scaled_sum[14] -= (float)ksum14 * izp * scale14; + packed_k_scaled_sum[15] -= (float)ksum15 * izp * scale15; + + ((uint16_t*) out)[0] = math_cvt_bf16_fp32(scale0 / 16.0f); + ((uint16_t*) out)[1] = math_cvt_bf16_fp32(scale1 / 16.0f); + ((uint16_t*) out)[2] = math_cvt_bf16_fp32(scale2 / 16.0f); + ((uint16_t*) out)[3] = math_cvt_bf16_fp32(scale3 / 16.0f); + ((uint16_t*) out)[4] = math_cvt_bf16_fp32(scale4 / 16.0f); + ((uint16_t*) out)[5] = math_cvt_bf16_fp32(scale5 / 16.0f); + ((uint16_t*) out)[6] = math_cvt_bf16_fp32(scale6 / 16.0f); + ((uint16_t*) out)[7] = math_cvt_bf16_fp32(scale7 / 16.0f); + ((uint16_t*) out)[8] = math_cvt_bf16_fp32(scale8 / 16.0f); + ((uint16_t*) out)[9] = math_cvt_bf16_fp32(scale9 / 16.0f); + ((uint16_t*) out)[10] = math_cvt_bf16_fp32(scale10 / 16.0f); + ((uint16_t*) out)[11] = math_cvt_bf16_fp32(scale11 / 16.0f); + ((uint16_t*) out)[12] = math_cvt_bf16_fp32(scale12 / 16.0f); + ((uint16_t*) out)[13] = math_cvt_bf16_fp32(scale13 / 16.0f); + ((uint16_t*) out)[14] = math_cvt_bf16_fp32(scale14 / 16.0f); + ((uint16_t*) out)[15] = math_cvt_bf16_fp32(scale15 / 16.0f); + + out += 16 * sizeof(uint16_t); + } + + + if XNN_LIKELY(b != NULL){ + ((uint32_t*) out)[0] = b[0]; + ((uint32_t*) out)[1] = b[1]; + ((uint32_t*) out)[2] = b[2]; + ((uint32_t*) out)[3] = b[3]; + ((uint32_t*) out)[4] = b[4]; + ((uint32_t*) out)[5] = b[5]; + ((uint32_t*) out)[6] = b[6]; + ((uint32_t*) out)[7] = b[7]; + ((uint32_t*) out)[8] = b[8]; + ((uint32_t*) out)[9] = b[9]; + ((uint32_t*) out)[10] = b[10]; + ((uint32_t*) out)[11] = b[11]; + ((uint32_t*) out)[12] = b[12]; + ((uint32_t*) out)[13] = b[13]; + ((uint32_t*) out)[14] = b[14]; + ((uint32_t*) out)[15] = b[15]; + b += 16; + } else { + ((uint32_t*) out)[0] = 0; + ((uint32_t*) out)[1] = 0; + ((uint32_t*) out)[2] = 0; + ((uint32_t*) out)[3] = 0; + ((uint32_t*) out)[4] = 0; + ((uint32_t*) out)[5] = 0; + ((uint32_t*) out)[6] = 0; + ((uint32_t*) out)[7] = 0; + ((uint32_t*) out)[8] = 0; + ((uint32_t*) out)[9] = 0; + ((uint32_t*) out)[10] = 0; + ((uint32_t*) out)[11] = 0; + ((uint32_t*) out)[12] = 0; + ((uint32_t*) out)[13] = 0; + ((uint32_t*) out)[14] = 0; + ((uint32_t*) out)[15] = 0; + } + out += 16 * sizeof(uint32_t); + w0 = w15; + s0 = s15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + float* packed_k_scaled_sum = (float*) out; + ((float*) out)[0] = 0; + ((float*) out)[1] = 0; + ((float*) out)[2] = 0; + ((float*) out)[3] = 0; + ((float*) out)[4] = 0; + ((float*) out)[5] = 0; + ((float*) out)[6] = 0; + ((float*) out)[7] = 0; + ((float*) out)[8] = 0; + ((float*) out)[9] = 0; + ((float*) out)[10] = 0; + ((float*) out)[11] = 0; + ((float*) out)[12] = 0; + ((float*) out)[13] = 0; + ((float*) out)[14] = 0; + ((float*) out)[15] = 0; + out += 16 
* sizeof(float); + // NR remainder has less than 16 + const uint8_t* w1 = w0 + (kc >> 1); + const uint16_t* s1 = s0 + num_blocks; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + s1 = s0; + } + const uint8_t* w2 = w1 + (kc >> 1); + const uint16_t* s2 = s1 + num_blocks; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + s2 = s1; + } + const uint8_t* w3 = w2 + (kc >> 1); + const uint16_t* s3 = s2 + num_blocks; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + s3 = s2; + } + const uint8_t* w4 = w3 + (kc >> 1); + const uint16_t* s4 = s3 + num_blocks; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + s4 = s3; + } + const uint8_t* w5 = w4 + (kc >> 1); + const uint16_t* s5 = s4 + num_blocks; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + s5 = s4; + } + const uint8_t* w6 = w5 + (kc >> 1); + const uint16_t* s6 = s5 + num_blocks; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + s6 = s5; + } + const uint8_t* w7 = w6 + (kc >> 1); + const uint16_t* s7 = s6 + num_blocks; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + s7 = s6; + } + const uint8_t* w8 = w7 + (kc >> 1); + const uint16_t* s8 = s7 + num_blocks; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + s8 = s7; + } + const uint8_t* w9 = w8 + (kc >> 1); + const uint16_t* s9 = s8 + num_blocks; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + s9 = s8; + } + const uint8_t* w10 = w9 + (kc >> 1); + const uint16_t* s10 = s9 + num_blocks; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + s10 = s9; + } + const uint8_t* w11 = w10 + (kc >> 1); + const uint16_t* s11 = s10 + num_blocks; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + s11 = s10; + } + const uint8_t* w12 = w11 + (kc >> 1); + const uint16_t* s12 = s11 + num_blocks; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + s12 = s11; + } + const uint8_t* w13 = w12 + (kc >> 1); + const uint16_t* s13 = s12 + num_blocks; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + s13 = s12; + } + const uint8_t* w14 = w13 + (kc >> 1); + const uint16_t* s14 = s13 + num_blocks; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + s14 = s13; + } + + size_t kb = kc; + // Process k by blocks (bl) + for (; kb >= bl; kb-=bl) { + // Initialize KSum as subtracting bl zero points (8) + int32_t ksum0 = 0; + int32_t ksum1 = 0; + int32_t ksum2 = 0; + int32_t ksum3 = 0; + int32_t ksum4 = 0; + int32_t ksum5 = 0; + int32_t ksum6 = 0; + int32_t ksum7 = 0; + int32_t ksum8 = 0; + int32_t ksum9 = 0; + int32_t ksum10 = 0; + int32_t ksum11 = 0; + int32_t ksum12 = 0; + int32_t ksum13 = 0; + int32_t ksum14 = 0; + size_t k = bl; + for(; k >= 8; k-=8) { + const uint8_t v0x0 = w0[0] & 0xF; + const uint8_t v0x1 = w0[0] >> 4; + const uint8_t v0x2 = w0[1] & 0xF; + const uint8_t v0x3 = w0[1] >> 4; + const uint8_t v0x4 = w0[2] & 0xF; + const uint8_t v0x5 = w0[2] >> 4; + const uint8_t v0x6 = w0[3] & 0xF; + const uint8_t v0x7 = w0[3] >> 4; + w0 += 4; + + ksum0 += (uint32_t) (v0x0); + ksum0 += (uint32_t) (v0x1); + ksum0 += (uint32_t) (v0x2); + ksum0 += (uint32_t) (v0x3); + ksum0 += (uint32_t) (v0x4); + ksum0 += (uint32_t) (v0x5); + ksum0 += (uint32_t) (v0x6); + ksum0 += (uint32_t) (v0x7); + // Subtract 8 zero points (8) + ksum0 -= 64; + + out[0] = (v0x0 | (v0x4 << 4)) ^ 0x88; + out[1] = (v0x1 | (v0x5 << 4)) ^ 0x88; + out[2] = (v0x2 | (v0x6 << 4)) ^ 0x88; + out[3] = (v0x3 | (v0x7 << 4)) ^ 0x88; + const uint8_t v1x0 = w1[0] & 0xF; + const uint8_t v1x1 = w1[0] >> 4; + const uint8_t v1x2 = w1[1] & 0xF; + const uint8_t v1x3 = w1[1] >> 4; + const uint8_t v1x4 = w1[2] & 0xF; + const uint8_t v1x5 = w1[2] >> 4; + const uint8_t v1x6 = w1[3] & 0xF; + const uint8_t v1x7 = w1[3] >> 4; + w1 += 4; + + ksum1 += 
(uint32_t) (v1x0); + ksum1 += (uint32_t) (v1x1); + ksum1 += (uint32_t) (v1x2); + ksum1 += (uint32_t) (v1x3); + ksum1 += (uint32_t) (v1x4); + ksum1 += (uint32_t) (v1x5); + ksum1 += (uint32_t) (v1x6); + ksum1 += (uint32_t) (v1x7); + // Subtract 8 zero points (8) + ksum1 -= 64; + + out[4] = (v1x0 | (v1x4 << 4)) ^ 0x88; + out[5] = (v1x1 | (v1x5 << 4)) ^ 0x88; + out[6] = (v1x2 | (v1x6 << 4)) ^ 0x88; + out[7] = (v1x3 | (v1x7 << 4)) ^ 0x88; + const uint8_t v2x0 = w2[0] & 0xF; + const uint8_t v2x1 = w2[0] >> 4; + const uint8_t v2x2 = w2[1] & 0xF; + const uint8_t v2x3 = w2[1] >> 4; + const uint8_t v2x4 = w2[2] & 0xF; + const uint8_t v2x5 = w2[2] >> 4; + const uint8_t v2x6 = w2[3] & 0xF; + const uint8_t v2x7 = w2[3] >> 4; + w2 += 4; + + ksum2 += (uint32_t) (v2x0); + ksum2 += (uint32_t) (v2x1); + ksum2 += (uint32_t) (v2x2); + ksum2 += (uint32_t) (v2x3); + ksum2 += (uint32_t) (v2x4); + ksum2 += (uint32_t) (v2x5); + ksum2 += (uint32_t) (v2x6); + ksum2 += (uint32_t) (v2x7); + // Subtract 8 zero points (8) + ksum2 -= 64; + + out[8] = (v2x0 | (v2x4 << 4)) ^ 0x88; + out[9] = (v2x1 | (v2x5 << 4)) ^ 0x88; + out[10] = (v2x2 | (v2x6 << 4)) ^ 0x88; + out[11] = (v2x3 | (v2x7 << 4)) ^ 0x88; + const uint8_t v3x0 = w3[0] & 0xF; + const uint8_t v3x1 = w3[0] >> 4; + const uint8_t v3x2 = w3[1] & 0xF; + const uint8_t v3x3 = w3[1] >> 4; + const uint8_t v3x4 = w3[2] & 0xF; + const uint8_t v3x5 = w3[2] >> 4; + const uint8_t v3x6 = w3[3] & 0xF; + const uint8_t v3x7 = w3[3] >> 4; + w3 += 4; + + ksum3 += (uint32_t) (v3x0); + ksum3 += (uint32_t) (v3x1); + ksum3 += (uint32_t) (v3x2); + ksum3 += (uint32_t) (v3x3); + ksum3 += (uint32_t) (v3x4); + ksum3 += (uint32_t) (v3x5); + ksum3 += (uint32_t) (v3x6); + ksum3 += (uint32_t) (v3x7); + // Subtract 8 zero points (8) + ksum3 -= 64; + + out[12] = (v3x0 | (v3x4 << 4)) ^ 0x88; + out[13] = (v3x1 | (v3x5 << 4)) ^ 0x88; + out[14] = (v3x2 | (v3x6 << 4)) ^ 0x88; + out[15] = (v3x3 | (v3x7 << 4)) ^ 0x88; + const uint8_t v4x0 = w4[0] & 0xF; + const uint8_t v4x1 = w4[0] >> 4; + const uint8_t v4x2 = w4[1] & 0xF; + const uint8_t v4x3 = w4[1] >> 4; + const uint8_t v4x4 = w4[2] & 0xF; + const uint8_t v4x5 = w4[2] >> 4; + const uint8_t v4x6 = w4[3] & 0xF; + const uint8_t v4x7 = w4[3] >> 4; + w4 += 4; + + ksum4 += (uint32_t) (v4x0); + ksum4 += (uint32_t) (v4x1); + ksum4 += (uint32_t) (v4x2); + ksum4 += (uint32_t) (v4x3); + ksum4 += (uint32_t) (v4x4); + ksum4 += (uint32_t) (v4x5); + ksum4 += (uint32_t) (v4x6); + ksum4 += (uint32_t) (v4x7); + // Subtract 8 zero points (8) + ksum4 -= 64; + + out[16] = (v4x0 | (v4x4 << 4)) ^ 0x88; + out[17] = (v4x1 | (v4x5 << 4)) ^ 0x88; + out[18] = (v4x2 | (v4x6 << 4)) ^ 0x88; + out[19] = (v4x3 | (v4x7 << 4)) ^ 0x88; + const uint8_t v5x0 = w5[0] & 0xF; + const uint8_t v5x1 = w5[0] >> 4; + const uint8_t v5x2 = w5[1] & 0xF; + const uint8_t v5x3 = w5[1] >> 4; + const uint8_t v5x4 = w5[2] & 0xF; + const uint8_t v5x5 = w5[2] >> 4; + const uint8_t v5x6 = w5[3] & 0xF; + const uint8_t v5x7 = w5[3] >> 4; + w5 += 4; + + ksum5 += (uint32_t) (v5x0); + ksum5 += (uint32_t) (v5x1); + ksum5 += (uint32_t) (v5x2); + ksum5 += (uint32_t) (v5x3); + ksum5 += (uint32_t) (v5x4); + ksum5 += (uint32_t) (v5x5); + ksum5 += (uint32_t) (v5x6); + ksum5 += (uint32_t) (v5x7); + // Subtract 8 zero points (8) + ksum5 -= 64; + + out[20] = (v5x0 | (v5x4 << 4)) ^ 0x88; + out[21] = (v5x1 | (v5x5 << 4)) ^ 0x88; + out[22] = (v5x2 | (v5x6 << 4)) ^ 0x88; + out[23] = (v5x3 | (v5x7 << 4)) ^ 0x88; + const uint8_t v6x0 = w6[0] & 0xF; + const uint8_t v6x1 = w6[0] >> 4; + const uint8_t v6x2 = w6[1] & 0xF; + const 
uint8_t v6x3 = w6[1] >> 4; + const uint8_t v6x4 = w6[2] & 0xF; + const uint8_t v6x5 = w6[2] >> 4; + const uint8_t v6x6 = w6[3] & 0xF; + const uint8_t v6x7 = w6[3] >> 4; + w6 += 4; + + ksum6 += (uint32_t) (v6x0); + ksum6 += (uint32_t) (v6x1); + ksum6 += (uint32_t) (v6x2); + ksum6 += (uint32_t) (v6x3); + ksum6 += (uint32_t) (v6x4); + ksum6 += (uint32_t) (v6x5); + ksum6 += (uint32_t) (v6x6); + ksum6 += (uint32_t) (v6x7); + // Subtract 8 zero points (8) + ksum6 -= 64; + + out[24] = (v6x0 | (v6x4 << 4)) ^ 0x88; + out[25] = (v6x1 | (v6x5 << 4)) ^ 0x88; + out[26] = (v6x2 | (v6x6 << 4)) ^ 0x88; + out[27] = (v6x3 | (v6x7 << 4)) ^ 0x88; + const uint8_t v7x0 = w7[0] & 0xF; + const uint8_t v7x1 = w7[0] >> 4; + const uint8_t v7x2 = w7[1] & 0xF; + const uint8_t v7x3 = w7[1] >> 4; + const uint8_t v7x4 = w7[2] & 0xF; + const uint8_t v7x5 = w7[2] >> 4; + const uint8_t v7x6 = w7[3] & 0xF; + const uint8_t v7x7 = w7[3] >> 4; + w7 += 4; + + ksum7 += (uint32_t) (v7x0); + ksum7 += (uint32_t) (v7x1); + ksum7 += (uint32_t) (v7x2); + ksum7 += (uint32_t) (v7x3); + ksum7 += (uint32_t) (v7x4); + ksum7 += (uint32_t) (v7x5); + ksum7 += (uint32_t) (v7x6); + ksum7 += (uint32_t) (v7x7); + // Subtract 8 zero points (8) + ksum7 -= 64; + + out[28] = (v7x0 | (v7x4 << 4)) ^ 0x88; + out[29] = (v7x1 | (v7x5 << 4)) ^ 0x88; + out[30] = (v7x2 | (v7x6 << 4)) ^ 0x88; + out[31] = (v7x3 | (v7x7 << 4)) ^ 0x88; + const uint8_t v8x0 = w8[0] & 0xF; + const uint8_t v8x1 = w8[0] >> 4; + const uint8_t v8x2 = w8[1] & 0xF; + const uint8_t v8x3 = w8[1] >> 4; + const uint8_t v8x4 = w8[2] & 0xF; + const uint8_t v8x5 = w8[2] >> 4; + const uint8_t v8x6 = w8[3] & 0xF; + const uint8_t v8x7 = w8[3] >> 4; + w8 += 4; + + ksum8 += (uint32_t) (v8x0); + ksum8 += (uint32_t) (v8x1); + ksum8 += (uint32_t) (v8x2); + ksum8 += (uint32_t) (v8x3); + ksum8 += (uint32_t) (v8x4); + ksum8 += (uint32_t) (v8x5); + ksum8 += (uint32_t) (v8x6); + ksum8 += (uint32_t) (v8x7); + // Subtract 8 zero points (8) + ksum8 -= 64; + + out[32] = (v8x0 | (v8x4 << 4)) ^ 0x88; + out[33] = (v8x1 | (v8x5 << 4)) ^ 0x88; + out[34] = (v8x2 | (v8x6 << 4)) ^ 0x88; + out[35] = (v8x3 | (v8x7 << 4)) ^ 0x88; + const uint8_t v9x0 = w9[0] & 0xF; + const uint8_t v9x1 = w9[0] >> 4; + const uint8_t v9x2 = w9[1] & 0xF; + const uint8_t v9x3 = w9[1] >> 4; + const uint8_t v9x4 = w9[2] & 0xF; + const uint8_t v9x5 = w9[2] >> 4; + const uint8_t v9x6 = w9[3] & 0xF; + const uint8_t v9x7 = w9[3] >> 4; + w9 += 4; + + ksum9 += (uint32_t) (v9x0); + ksum9 += (uint32_t) (v9x1); + ksum9 += (uint32_t) (v9x2); + ksum9 += (uint32_t) (v9x3); + ksum9 += (uint32_t) (v9x4); + ksum9 += (uint32_t) (v9x5); + ksum9 += (uint32_t) (v9x6); + ksum9 += (uint32_t) (v9x7); + // Subtract 8 zero points (8) + ksum9 -= 64; + + out[36] = (v9x0 | (v9x4 << 4)) ^ 0x88; + out[37] = (v9x1 | (v9x5 << 4)) ^ 0x88; + out[38] = (v9x2 | (v9x6 << 4)) ^ 0x88; + out[39] = (v9x3 | (v9x7 << 4)) ^ 0x88; + const uint8_t v10x0 = w10[0] & 0xF; + const uint8_t v10x1 = w10[0] >> 4; + const uint8_t v10x2 = w10[1] & 0xF; + const uint8_t v10x3 = w10[1] >> 4; + const uint8_t v10x4 = w10[2] & 0xF; + const uint8_t v10x5 = w10[2] >> 4; + const uint8_t v10x6 = w10[3] & 0xF; + const uint8_t v10x7 = w10[3] >> 4; + w10 += 4; + + ksum10 += (uint32_t) (v10x0); + ksum10 += (uint32_t) (v10x1); + ksum10 += (uint32_t) (v10x2); + ksum10 += (uint32_t) (v10x3); + ksum10 += (uint32_t) (v10x4); + ksum10 += (uint32_t) (v10x5); + ksum10 += (uint32_t) (v10x6); + ksum10 += (uint32_t) (v10x7); + // Subtract 8 zero points (8) + ksum10 -= 64; + + out[40] = (v10x0 | (v10x4 << 4)) ^ 0x88; + 
out[41] = (v10x1 | (v10x5 << 4)) ^ 0x88; + out[42] = (v10x2 | (v10x6 << 4)) ^ 0x88; + out[43] = (v10x3 | (v10x7 << 4)) ^ 0x88; + const uint8_t v11x0 = w11[0] & 0xF; + const uint8_t v11x1 = w11[0] >> 4; + const uint8_t v11x2 = w11[1] & 0xF; + const uint8_t v11x3 = w11[1] >> 4; + const uint8_t v11x4 = w11[2] & 0xF; + const uint8_t v11x5 = w11[2] >> 4; + const uint8_t v11x6 = w11[3] & 0xF; + const uint8_t v11x7 = w11[3] >> 4; + w11 += 4; + + ksum11 += (uint32_t) (v11x0); + ksum11 += (uint32_t) (v11x1); + ksum11 += (uint32_t) (v11x2); + ksum11 += (uint32_t) (v11x3); + ksum11 += (uint32_t) (v11x4); + ksum11 += (uint32_t) (v11x5); + ksum11 += (uint32_t) (v11x6); + ksum11 += (uint32_t) (v11x7); + // Subtract 8 zero points (8) + ksum11 -= 64; + + out[44] = (v11x0 | (v11x4 << 4)) ^ 0x88; + out[45] = (v11x1 | (v11x5 << 4)) ^ 0x88; + out[46] = (v11x2 | (v11x6 << 4)) ^ 0x88; + out[47] = (v11x3 | (v11x7 << 4)) ^ 0x88; + const uint8_t v12x0 = w12[0] & 0xF; + const uint8_t v12x1 = w12[0] >> 4; + const uint8_t v12x2 = w12[1] & 0xF; + const uint8_t v12x3 = w12[1] >> 4; + const uint8_t v12x4 = w12[2] & 0xF; + const uint8_t v12x5 = w12[2] >> 4; + const uint8_t v12x6 = w12[3] & 0xF; + const uint8_t v12x7 = w12[3] >> 4; + w12 += 4; + + ksum12 += (uint32_t) (v12x0); + ksum12 += (uint32_t) (v12x1); + ksum12 += (uint32_t) (v12x2); + ksum12 += (uint32_t) (v12x3); + ksum12 += (uint32_t) (v12x4); + ksum12 += (uint32_t) (v12x5); + ksum12 += (uint32_t) (v12x6); + ksum12 += (uint32_t) (v12x7); + // Subtract 8 zero points (8) + ksum12 -= 64; + + out[48] = (v12x0 | (v12x4 << 4)) ^ 0x88; + out[49] = (v12x1 | (v12x5 << 4)) ^ 0x88; + out[50] = (v12x2 | (v12x6 << 4)) ^ 0x88; + out[51] = (v12x3 | (v12x7 << 4)) ^ 0x88; + const uint8_t v13x0 = w13[0] & 0xF; + const uint8_t v13x1 = w13[0] >> 4; + const uint8_t v13x2 = w13[1] & 0xF; + const uint8_t v13x3 = w13[1] >> 4; + const uint8_t v13x4 = w13[2] & 0xF; + const uint8_t v13x5 = w13[2] >> 4; + const uint8_t v13x6 = w13[3] & 0xF; + const uint8_t v13x7 = w13[3] >> 4; + w13 += 4; + + ksum13 += (uint32_t) (v13x0); + ksum13 += (uint32_t) (v13x1); + ksum13 += (uint32_t) (v13x2); + ksum13 += (uint32_t) (v13x3); + ksum13 += (uint32_t) (v13x4); + ksum13 += (uint32_t) (v13x5); + ksum13 += (uint32_t) (v13x6); + ksum13 += (uint32_t) (v13x7); + // Subtract 8 zero points (8) + ksum13 -= 64; + + out[52] = (v13x0 | (v13x4 << 4)) ^ 0x88; + out[53] = (v13x1 | (v13x5 << 4)) ^ 0x88; + out[54] = (v13x2 | (v13x6 << 4)) ^ 0x88; + out[55] = (v13x3 | (v13x7 << 4)) ^ 0x88; + const uint8_t v14x0 = w14[0] & 0xF; + const uint8_t v14x1 = w14[0] >> 4; + const uint8_t v14x2 = w14[1] & 0xF; + const uint8_t v14x3 = w14[1] >> 4; + const uint8_t v14x4 = w14[2] & 0xF; + const uint8_t v14x5 = w14[2] >> 4; + const uint8_t v14x6 = w14[3] & 0xF; + const uint8_t v14x7 = w14[3] >> 4; + w14 += 4; + + ksum14 += (uint32_t) (v14x0); + ksum14 += (uint32_t) (v14x1); + ksum14 += (uint32_t) (v14x2); + ksum14 += (uint32_t) (v14x3); + ksum14 += (uint32_t) (v14x4); + ksum14 += (uint32_t) (v14x5); + ksum14 += (uint32_t) (v14x6); + ksum14 += (uint32_t) (v14x7); + // Subtract 8 zero points (8) + ksum14 -= 64; + + out[56] = (v14x0 | (v14x4 << 4)) ^ 0x88; + out[57] = (v14x1 | (v14x5 << 4)) ^ 0x88; + out[58] = (v14x2 | (v14x6 << 4)) ^ 0x88; + out[59] = (v14x3 | (v14x7 << 4)) ^ 0x88; + + out += 64; + } + float scale0 = math_cvt_fp32_bf16(s0[0]); + float scale1 = math_cvt_fp32_bf16(s1[0]); + float scale2 = math_cvt_fp32_bf16(s2[0]); + float scale3 = math_cvt_fp32_bf16(s3[0]); + float scale4 = math_cvt_fp32_bf16(s4[0]); + float scale5 = 
math_cvt_fp32_bf16(s5[0]); + float scale6 = math_cvt_fp32_bf16(s6[0]); + float scale7 = math_cvt_fp32_bf16(s7[0]); + float scale8 = math_cvt_fp32_bf16(s8[0]); + float scale9 = math_cvt_fp32_bf16(s9[0]); + float scale10 = math_cvt_fp32_bf16(s10[0]); + float scale11 = math_cvt_fp32_bf16(s11[0]); + float scale12 = math_cvt_fp32_bf16(s12[0]); + float scale13 = math_cvt_fp32_bf16(s13[0]); + float scale14 = math_cvt_fp32_bf16(s14[0]); + s0 += 1; + s1 += 1; + s2 += 1; + s3 += 1; + s4 += 1; + s5 += 1; + s6 += 1; + s7 += 1; + s8 += 1; + s9 += 1; + s10 += 1; + s11 += 1; + s12 += 1; + s13 += 1; + s14 += 1; + + + packed_k_scaled_sum[0] -= (float)ksum0 * izp * scale0; + packed_k_scaled_sum[1] -= (float)ksum1 * izp * scale1; + packed_k_scaled_sum[2] -= (float)ksum2 * izp * scale2; + packed_k_scaled_sum[3] -= (float)ksum3 * izp * scale3; + packed_k_scaled_sum[4] -= (float)ksum4 * izp * scale4; + packed_k_scaled_sum[5] -= (float)ksum5 * izp * scale5; + packed_k_scaled_sum[6] -= (float)ksum6 * izp * scale6; + packed_k_scaled_sum[7] -= (float)ksum7 * izp * scale7; + packed_k_scaled_sum[8] -= (float)ksum8 * izp * scale8; + packed_k_scaled_sum[9] -= (float)ksum9 * izp * scale9; + packed_k_scaled_sum[10] -= (float)ksum10 * izp * scale10; + packed_k_scaled_sum[11] -= (float)ksum11 * izp * scale11; + packed_k_scaled_sum[12] -= (float)ksum12 * izp * scale12; + packed_k_scaled_sum[13] -= (float)ksum13 * izp * scale13; + packed_k_scaled_sum[14] -= (float)ksum14 * izp * scale14; + + ((uint16_t*) out)[0] = math_cvt_bf16_fp32(scale0 / 16.0f); + ((uint16_t*) out)[1] = math_cvt_bf16_fp32(scale1 / 16.0f); + ((uint16_t*) out)[2] = math_cvt_bf16_fp32(scale2 / 16.0f); + ((uint16_t*) out)[3] = math_cvt_bf16_fp32(scale3 / 16.0f); + ((uint16_t*) out)[4] = math_cvt_bf16_fp32(scale4 / 16.0f); + ((uint16_t*) out)[5] = math_cvt_bf16_fp32(scale5 / 16.0f); + ((uint16_t*) out)[6] = math_cvt_bf16_fp32(scale6 / 16.0f); + ((uint16_t*) out)[7] = math_cvt_bf16_fp32(scale7 / 16.0f); + ((uint16_t*) out)[8] = math_cvt_bf16_fp32(scale8 / 16.0f); + ((uint16_t*) out)[9] = math_cvt_bf16_fp32(scale9 / 16.0f); + ((uint16_t*) out)[10] = math_cvt_bf16_fp32(scale10 / 16.0f); + ((uint16_t*) out)[11] = math_cvt_bf16_fp32(scale11 / 16.0f); + ((uint16_t*) out)[12] = math_cvt_bf16_fp32(scale12 / 16.0f); + ((uint16_t*) out)[13] = math_cvt_bf16_fp32(scale13 / 16.0f); + ((uint16_t*) out)[14] = math_cvt_bf16_fp32(scale14 / 16.0f); + + out += 16 * sizeof(uint16_t); + } + + + if XNN_LIKELY(b != NULL){ + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while(--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while(--nb != 0); + } + out += 16 * sizeof(uint32_t); + } + } while (--g != 0); +} diff --git a/src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c b/src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c new file mode 100644 index 00000000000..42b9fea148b --- /dev/null +++ b/src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c @@ -0,0 +1,1858 @@ +// Auto-generated file. Do not edit! +// Template: src/qb4-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
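For orientation, the generated x16c8 kernel below packs each group of nr = 16 output channels as 16 float running k-scaled sums, then, for every block of bl input channels, 16 x (bl/2) bytes of packed 4-bit weights followed by 16 bf16 block scales, and finally 16 int32 biases. A minimal sketch of the per-channel packed stride this layout implies, assuming kc is a multiple of bl (the helper name is hypothetical and not part of the patch):

#include <stddef.h>
#include <stdint.h>

// Bytes per output channel in the x16c8 blockwise layout written by the kernel below.
static size_t qb4_x16c8_packed_stride_per_channel(size_t kc, size_t bl) {
  const size_t num_blocks = kc / bl;          // the kernel asserts kc % bl == 0
  return sizeof(float)                        // running k_scaled_sum
       + num_blocks * (bl / 2                 // bl 4-bit weights, two per byte
                       + sizeof(uint16_t))    // one bf16 scale per block
       + sizeof(int32_t);                     // bias
}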
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "xnnpack/packw.h"
+
+void xnn_qb4_packw_gemm_goi_ukernel_x16c8__scalar(
+  size_t g,
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  size_t sr,
+  size_t bl,
+  const uint8_t* weights,
+  const int32_t* bias,
+  const void* scale,
+  int8_t* packed_weights,
+  size_t extra_bytes_bl,
+  size_t extra_bytes_n,
+  const void* params)
+{
+  assert(g != 0);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(nr == 16);
+  assert(kr == 8);
+  assert(sr == 1);
+  assert(weights != NULL);
+  assert(packed_weights != NULL);
+  assert(extra_bytes_bl == nr * sizeof(uint16_t));
+  assert(extra_bytes_n == nr * sizeof(float));
+  assert(params != NULL);
+  assert(kc % bl == 0);
+  size_t num_blocks = kc / bl;
+
+  int8_t* out = (int8_t*) packed_weights;
+  const int32_t* b = (const int32_t*) bias;
+  const uint32_t izp = (uint32_t) (((const struct xnn_qs8_qc4w_packing_params*) params)->input_zero_point + 0);
+
+  do {
+    // NC main loop: multiples of 16
+    const uint8_t* w0 = (const uint8_t*) weights;
+    const uint16_t* s0 = (const uint16_t*) scale;
+    size_t n = nc;
+    for (;n >= 16; n -= 16) {
+      float* packed_k_scaled_sum = (float*) out;
+      ((float*) out)[0] = 0;
+      ((float*) out)[1] = 0;
+      ((float*) out)[2] = 0;
+      ((float*) out)[3] = 0;
+      ((float*) out)[4] = 0;
+      ((float*) out)[5] = 0;
+      ((float*) out)[6] = 0;
+      ((float*) out)[7] = 0;
+      ((float*) out)[8] = 0;
+      ((float*) out)[9] = 0;
+      ((float*) out)[10] = 0;
+      ((float*) out)[11] = 0;
+      ((float*) out)[12] = 0;
+      ((float*) out)[13] = 0;
+      ((float*) out)[14] = 0;
+      ((float*) out)[15] = 0;
+      out += 16 * sizeof(float);
+
+      // KC/2 bytes hold KC nibbles (two 4-bit weights per byte)
+      const uint8_t* w1 = w0 + (kc >> 1);
+      const uint8_t* w2 = w1 + (kc >> 1);
+      const uint8_t* w3 = w2 + (kc >> 1);
+      const uint8_t* w4 = w3 + (kc >> 1);
+      const uint8_t* w5 = w4 + (kc >> 1);
+      const uint8_t* w6 = w5 + (kc >> 1);
+      const uint8_t* w7 = w6 + (kc >> 1);
+      const uint8_t* w8 = w7 + (kc >> 1);
+      const uint8_t* w9 = w8 + (kc >> 1);
+      const uint8_t* w10 = w9 + (kc >> 1);
+      const uint8_t* w11 = w10 + (kc >> 1);
+      const uint8_t* w12 = w11 + (kc >> 1);
+      const uint8_t* w13 = w12 + (kc >> 1);
+      const uint8_t* w14 = w13 + (kc >> 1);
+      const uint8_t* w15 = w14 + (kc >> 1);
+
+      // scales
+      const uint16_t* s1 = s0 + num_blocks;
+      const uint16_t* s2 = s1 + num_blocks;
+      const uint16_t* s3 = s2 + num_blocks;
+      const uint16_t* s4 = s3 + num_blocks;
+      const uint16_t* s5 = s4 + num_blocks;
+      const uint16_t* s6 = s5 + num_blocks;
+      const uint16_t* s7 = s6 + num_blocks;
+      const uint16_t* s8 = s7 + num_blocks;
+      const uint16_t* s9 = s8 + num_blocks;
+      const uint16_t* s10 = s9 + num_blocks;
+      const uint16_t* s11 = s10 + num_blocks;
+      const uint16_t* s12 = s11 + num_blocks;
+      const uint16_t* s13 = s12 + num_blocks;
+      const uint16_t* s14 = s13 + num_blocks;
+      const uint16_t* s15 = s14 + num_blocks;
+
+
+      size_t kb = kc;
+      // Process k by blocks (bl)
+      for (; kb >= bl; kb-=bl) {
+        // Initialize KSums to 0; the bl zero points (8 each) are subtracted as the block is packed
+        int32_t ksum0 = 0;
+        int32_t ksum1 = 0;
+        int32_t ksum2 = 0;
+        int32_t ksum3 = 0;
+        int32_t ksum4 = 0;
+        int32_t ksum5 = 0;
+        int32_t ksum6 = 0;
+        int32_t ksum7 = 0;
+        int32_t ksum8 = 0;
+        int32_t ksum9 = 0;
+        int32_t ksum10 = 0;
+        int32_t ksum11 = 0;
+        int32_t ksum12 = 0;
+        int32_t ksum13 = 0;
+        int32_t ksum14 = 0;
+        int32_t ksum15 = 0;
+        size_t k = bl;
+        for(; k >= 16; k-=16) {
+          const uint8_t v0x0 = w0[0] & 0xF;
+          const uint8_t v0x1 = w0[0] >> 4;
+          const uint8_t v0x2 = w0[1] & 0xF;
+          const uint8_t v0x3 = w0[1] >> 4;
+          const uint8_t v0x4 = w0[2] & 0xF;
+          const
uint8_t v0x5 = w0[2] >> 4; + const uint8_t v0x6 = w0[3] & 0xF; + const uint8_t v0x7 = w0[3] >> 4; + const uint8_t v0x8 = w0[4] & 0xF; + const uint8_t v0x9 = w0[4] >> 4; + const uint8_t v0x10 = w0[5] & 0xF; + const uint8_t v0x11 = w0[5] >> 4; + const uint8_t v0x12 = w0[6] & 0xF; + const uint8_t v0x13 = w0[6] >> 4; + const uint8_t v0x14 = w0[7] & 0xF; + const uint8_t v0x15 = w0[7] >> 4; + w0 += 8; + + ksum0 += (uint32_t) (v0x0); + ksum0 += (uint32_t) (v0x1); + ksum0 += (uint32_t) (v0x2); + ksum0 += (uint32_t) (v0x3); + ksum0 += (uint32_t) (v0x4); + ksum0 += (uint32_t) (v0x5); + ksum0 += (uint32_t) (v0x6); + ksum0 += (uint32_t) (v0x7); + ksum0 += (uint32_t) (v0x8); + ksum0 += (uint32_t) (v0x9); + ksum0 += (uint32_t) (v0x10); + ksum0 += (uint32_t) (v0x11); + ksum0 += (uint32_t) (v0x12); + ksum0 += (uint32_t) (v0x13); + ksum0 += (uint32_t) (v0x14); + ksum0 += (uint32_t) (v0x15); + // Subtract 16 zero points (8) + ksum0 -= 128; + + out[0] = (v0x0 | (v0x8 << 4)) ^ 0x88; + out[1] = (v0x1 | (v0x9 << 4)) ^ 0x88; + out[2] = (v0x2 | (v0x10 << 4)) ^ 0x88; + out[3] = (v0x3 | (v0x11 << 4)) ^ 0x88; + out[4] = (v0x4 | (v0x12 << 4)) ^ 0x88; + out[5] = (v0x5 | (v0x13 << 4)) ^ 0x88; + out[6] = (v0x6 | (v0x14 << 4)) ^ 0x88; + out[7] = (v0x7 | (v0x15 << 4)) ^ 0x88; + const uint8_t v1x0 = w1[0] & 0xF; + const uint8_t v1x1 = w1[0] >> 4; + const uint8_t v1x2 = w1[1] & 0xF; + const uint8_t v1x3 = w1[1] >> 4; + const uint8_t v1x4 = w1[2] & 0xF; + const uint8_t v1x5 = w1[2] >> 4; + const uint8_t v1x6 = w1[3] & 0xF; + const uint8_t v1x7 = w1[3] >> 4; + const uint8_t v1x8 = w1[4] & 0xF; + const uint8_t v1x9 = w1[4] >> 4; + const uint8_t v1x10 = w1[5] & 0xF; + const uint8_t v1x11 = w1[5] >> 4; + const uint8_t v1x12 = w1[6] & 0xF; + const uint8_t v1x13 = w1[6] >> 4; + const uint8_t v1x14 = w1[7] & 0xF; + const uint8_t v1x15 = w1[7] >> 4; + w1 += 8; + + ksum1 += (uint32_t) (v1x0); + ksum1 += (uint32_t) (v1x1); + ksum1 += (uint32_t) (v1x2); + ksum1 += (uint32_t) (v1x3); + ksum1 += (uint32_t) (v1x4); + ksum1 += (uint32_t) (v1x5); + ksum1 += (uint32_t) (v1x6); + ksum1 += (uint32_t) (v1x7); + ksum1 += (uint32_t) (v1x8); + ksum1 += (uint32_t) (v1x9); + ksum1 += (uint32_t) (v1x10); + ksum1 += (uint32_t) (v1x11); + ksum1 += (uint32_t) (v1x12); + ksum1 += (uint32_t) (v1x13); + ksum1 += (uint32_t) (v1x14); + ksum1 += (uint32_t) (v1x15); + // Subtract 16 zero points (8) + ksum1 -= 128; + + out[8] = (v1x0 | (v1x8 << 4)) ^ 0x88; + out[9] = (v1x1 | (v1x9 << 4)) ^ 0x88; + out[10] = (v1x2 | (v1x10 << 4)) ^ 0x88; + out[11] = (v1x3 | (v1x11 << 4)) ^ 0x88; + out[12] = (v1x4 | (v1x12 << 4)) ^ 0x88; + out[13] = (v1x5 | (v1x13 << 4)) ^ 0x88; + out[14] = (v1x6 | (v1x14 << 4)) ^ 0x88; + out[15] = (v1x7 | (v1x15 << 4)) ^ 0x88; + const uint8_t v2x0 = w2[0] & 0xF; + const uint8_t v2x1 = w2[0] >> 4; + const uint8_t v2x2 = w2[1] & 0xF; + const uint8_t v2x3 = w2[1] >> 4; + const uint8_t v2x4 = w2[2] & 0xF; + const uint8_t v2x5 = w2[2] >> 4; + const uint8_t v2x6 = w2[3] & 0xF; + const uint8_t v2x7 = w2[3] >> 4; + const uint8_t v2x8 = w2[4] & 0xF; + const uint8_t v2x9 = w2[4] >> 4; + const uint8_t v2x10 = w2[5] & 0xF; + const uint8_t v2x11 = w2[5] >> 4; + const uint8_t v2x12 = w2[6] & 0xF; + const uint8_t v2x13 = w2[6] >> 4; + const uint8_t v2x14 = w2[7] & 0xF; + const uint8_t v2x15 = w2[7] >> 4; + w2 += 8; + + ksum2 += (uint32_t) (v2x0); + ksum2 += (uint32_t) (v2x1); + ksum2 += (uint32_t) (v2x2); + ksum2 += (uint32_t) (v2x3); + ksum2 += (uint32_t) (v2x4); + ksum2 += (uint32_t) (v2x5); + ksum2 += (uint32_t) (v2x6); + ksum2 += (uint32_t) (v2x7); + ksum2 
+= (uint32_t) (v2x8); + ksum2 += (uint32_t) (v2x9); + ksum2 += (uint32_t) (v2x10); + ksum2 += (uint32_t) (v2x11); + ksum2 += (uint32_t) (v2x12); + ksum2 += (uint32_t) (v2x13); + ksum2 += (uint32_t) (v2x14); + ksum2 += (uint32_t) (v2x15); + // Subtract 16 zero points (8) + ksum2 -= 128; + + out[16] = (v2x0 | (v2x8 << 4)) ^ 0x88; + out[17] = (v2x1 | (v2x9 << 4)) ^ 0x88; + out[18] = (v2x2 | (v2x10 << 4)) ^ 0x88; + out[19] = (v2x3 | (v2x11 << 4)) ^ 0x88; + out[20] = (v2x4 | (v2x12 << 4)) ^ 0x88; + out[21] = (v2x5 | (v2x13 << 4)) ^ 0x88; + out[22] = (v2x6 | (v2x14 << 4)) ^ 0x88; + out[23] = (v2x7 | (v2x15 << 4)) ^ 0x88; + const uint8_t v3x0 = w3[0] & 0xF; + const uint8_t v3x1 = w3[0] >> 4; + const uint8_t v3x2 = w3[1] & 0xF; + const uint8_t v3x3 = w3[1] >> 4; + const uint8_t v3x4 = w3[2] & 0xF; + const uint8_t v3x5 = w3[2] >> 4; + const uint8_t v3x6 = w3[3] & 0xF; + const uint8_t v3x7 = w3[3] >> 4; + const uint8_t v3x8 = w3[4] & 0xF; + const uint8_t v3x9 = w3[4] >> 4; + const uint8_t v3x10 = w3[5] & 0xF; + const uint8_t v3x11 = w3[5] >> 4; + const uint8_t v3x12 = w3[6] & 0xF; + const uint8_t v3x13 = w3[6] >> 4; + const uint8_t v3x14 = w3[7] & 0xF; + const uint8_t v3x15 = w3[7] >> 4; + w3 += 8; + + ksum3 += (uint32_t) (v3x0); + ksum3 += (uint32_t) (v3x1); + ksum3 += (uint32_t) (v3x2); + ksum3 += (uint32_t) (v3x3); + ksum3 += (uint32_t) (v3x4); + ksum3 += (uint32_t) (v3x5); + ksum3 += (uint32_t) (v3x6); + ksum3 += (uint32_t) (v3x7); + ksum3 += (uint32_t) (v3x8); + ksum3 += (uint32_t) (v3x9); + ksum3 += (uint32_t) (v3x10); + ksum3 += (uint32_t) (v3x11); + ksum3 += (uint32_t) (v3x12); + ksum3 += (uint32_t) (v3x13); + ksum3 += (uint32_t) (v3x14); + ksum3 += (uint32_t) (v3x15); + // Subtract 16 zero points (8) + ksum3 -= 128; + + out[24] = (v3x0 | (v3x8 << 4)) ^ 0x88; + out[25] = (v3x1 | (v3x9 << 4)) ^ 0x88; + out[26] = (v3x2 | (v3x10 << 4)) ^ 0x88; + out[27] = (v3x3 | (v3x11 << 4)) ^ 0x88; + out[28] = (v3x4 | (v3x12 << 4)) ^ 0x88; + out[29] = (v3x5 | (v3x13 << 4)) ^ 0x88; + out[30] = (v3x6 | (v3x14 << 4)) ^ 0x88; + out[31] = (v3x7 | (v3x15 << 4)) ^ 0x88; + const uint8_t v4x0 = w4[0] & 0xF; + const uint8_t v4x1 = w4[0] >> 4; + const uint8_t v4x2 = w4[1] & 0xF; + const uint8_t v4x3 = w4[1] >> 4; + const uint8_t v4x4 = w4[2] & 0xF; + const uint8_t v4x5 = w4[2] >> 4; + const uint8_t v4x6 = w4[3] & 0xF; + const uint8_t v4x7 = w4[3] >> 4; + const uint8_t v4x8 = w4[4] & 0xF; + const uint8_t v4x9 = w4[4] >> 4; + const uint8_t v4x10 = w4[5] & 0xF; + const uint8_t v4x11 = w4[5] >> 4; + const uint8_t v4x12 = w4[6] & 0xF; + const uint8_t v4x13 = w4[6] >> 4; + const uint8_t v4x14 = w4[7] & 0xF; + const uint8_t v4x15 = w4[7] >> 4; + w4 += 8; + + ksum4 += (uint32_t) (v4x0); + ksum4 += (uint32_t) (v4x1); + ksum4 += (uint32_t) (v4x2); + ksum4 += (uint32_t) (v4x3); + ksum4 += (uint32_t) (v4x4); + ksum4 += (uint32_t) (v4x5); + ksum4 += (uint32_t) (v4x6); + ksum4 += (uint32_t) (v4x7); + ksum4 += (uint32_t) (v4x8); + ksum4 += (uint32_t) (v4x9); + ksum4 += (uint32_t) (v4x10); + ksum4 += (uint32_t) (v4x11); + ksum4 += (uint32_t) (v4x12); + ksum4 += (uint32_t) (v4x13); + ksum4 += (uint32_t) (v4x14); + ksum4 += (uint32_t) (v4x15); + // Subtract 16 zero points (8) + ksum4 -= 128; + + out[32] = (v4x0 | (v4x8 << 4)) ^ 0x88; + out[33] = (v4x1 | (v4x9 << 4)) ^ 0x88; + out[34] = (v4x2 | (v4x10 << 4)) ^ 0x88; + out[35] = (v4x3 | (v4x11 << 4)) ^ 0x88; + out[36] = (v4x4 | (v4x12 << 4)) ^ 0x88; + out[37] = (v4x5 | (v4x13 << 4)) ^ 0x88; + out[38] = (v4x6 | (v4x14 << 4)) ^ 0x88; + out[39] = (v4x7 | (v4x15 << 4)) ^ 0x88; + const uint8_t 
v5x0 = w5[0] & 0xF; + const uint8_t v5x1 = w5[0] >> 4; + const uint8_t v5x2 = w5[1] & 0xF; + const uint8_t v5x3 = w5[1] >> 4; + const uint8_t v5x4 = w5[2] & 0xF; + const uint8_t v5x5 = w5[2] >> 4; + const uint8_t v5x6 = w5[3] & 0xF; + const uint8_t v5x7 = w5[3] >> 4; + const uint8_t v5x8 = w5[4] & 0xF; + const uint8_t v5x9 = w5[4] >> 4; + const uint8_t v5x10 = w5[5] & 0xF; + const uint8_t v5x11 = w5[5] >> 4; + const uint8_t v5x12 = w5[6] & 0xF; + const uint8_t v5x13 = w5[6] >> 4; + const uint8_t v5x14 = w5[7] & 0xF; + const uint8_t v5x15 = w5[7] >> 4; + w5 += 8; + + ksum5 += (uint32_t) (v5x0); + ksum5 += (uint32_t) (v5x1); + ksum5 += (uint32_t) (v5x2); + ksum5 += (uint32_t) (v5x3); + ksum5 += (uint32_t) (v5x4); + ksum5 += (uint32_t) (v5x5); + ksum5 += (uint32_t) (v5x6); + ksum5 += (uint32_t) (v5x7); + ksum5 += (uint32_t) (v5x8); + ksum5 += (uint32_t) (v5x9); + ksum5 += (uint32_t) (v5x10); + ksum5 += (uint32_t) (v5x11); + ksum5 += (uint32_t) (v5x12); + ksum5 += (uint32_t) (v5x13); + ksum5 += (uint32_t) (v5x14); + ksum5 += (uint32_t) (v5x15); + // Subtract 16 zero points (8) + ksum5 -= 128; + + out[40] = (v5x0 | (v5x8 << 4)) ^ 0x88; + out[41] = (v5x1 | (v5x9 << 4)) ^ 0x88; + out[42] = (v5x2 | (v5x10 << 4)) ^ 0x88; + out[43] = (v5x3 | (v5x11 << 4)) ^ 0x88; + out[44] = (v5x4 | (v5x12 << 4)) ^ 0x88; + out[45] = (v5x5 | (v5x13 << 4)) ^ 0x88; + out[46] = (v5x6 | (v5x14 << 4)) ^ 0x88; + out[47] = (v5x7 | (v5x15 << 4)) ^ 0x88; + const uint8_t v6x0 = w6[0] & 0xF; + const uint8_t v6x1 = w6[0] >> 4; + const uint8_t v6x2 = w6[1] & 0xF; + const uint8_t v6x3 = w6[1] >> 4; + const uint8_t v6x4 = w6[2] & 0xF; + const uint8_t v6x5 = w6[2] >> 4; + const uint8_t v6x6 = w6[3] & 0xF; + const uint8_t v6x7 = w6[3] >> 4; + const uint8_t v6x8 = w6[4] & 0xF; + const uint8_t v6x9 = w6[4] >> 4; + const uint8_t v6x10 = w6[5] & 0xF; + const uint8_t v6x11 = w6[5] >> 4; + const uint8_t v6x12 = w6[6] & 0xF; + const uint8_t v6x13 = w6[6] >> 4; + const uint8_t v6x14 = w6[7] & 0xF; + const uint8_t v6x15 = w6[7] >> 4; + w6 += 8; + + ksum6 += (uint32_t) (v6x0); + ksum6 += (uint32_t) (v6x1); + ksum6 += (uint32_t) (v6x2); + ksum6 += (uint32_t) (v6x3); + ksum6 += (uint32_t) (v6x4); + ksum6 += (uint32_t) (v6x5); + ksum6 += (uint32_t) (v6x6); + ksum6 += (uint32_t) (v6x7); + ksum6 += (uint32_t) (v6x8); + ksum6 += (uint32_t) (v6x9); + ksum6 += (uint32_t) (v6x10); + ksum6 += (uint32_t) (v6x11); + ksum6 += (uint32_t) (v6x12); + ksum6 += (uint32_t) (v6x13); + ksum6 += (uint32_t) (v6x14); + ksum6 += (uint32_t) (v6x15); + // Subtract 16 zero points (8) + ksum6 -= 128; + + out[48] = (v6x0 | (v6x8 << 4)) ^ 0x88; + out[49] = (v6x1 | (v6x9 << 4)) ^ 0x88; + out[50] = (v6x2 | (v6x10 << 4)) ^ 0x88; + out[51] = (v6x3 | (v6x11 << 4)) ^ 0x88; + out[52] = (v6x4 | (v6x12 << 4)) ^ 0x88; + out[53] = (v6x5 | (v6x13 << 4)) ^ 0x88; + out[54] = (v6x6 | (v6x14 << 4)) ^ 0x88; + out[55] = (v6x7 | (v6x15 << 4)) ^ 0x88; + const uint8_t v7x0 = w7[0] & 0xF; + const uint8_t v7x1 = w7[0] >> 4; + const uint8_t v7x2 = w7[1] & 0xF; + const uint8_t v7x3 = w7[1] >> 4; + const uint8_t v7x4 = w7[2] & 0xF; + const uint8_t v7x5 = w7[2] >> 4; + const uint8_t v7x6 = w7[3] & 0xF; + const uint8_t v7x7 = w7[3] >> 4; + const uint8_t v7x8 = w7[4] & 0xF; + const uint8_t v7x9 = w7[4] >> 4; + const uint8_t v7x10 = w7[5] & 0xF; + const uint8_t v7x11 = w7[5] >> 4; + const uint8_t v7x12 = w7[6] & 0xF; + const uint8_t v7x13 = w7[6] >> 4; + const uint8_t v7x14 = w7[7] & 0xF; + const uint8_t v7x15 = w7[7] >> 4; + w7 += 8; + + ksum7 += (uint32_t) (v7x0); + ksum7 += (uint32_t) (v7x1); + ksum7 
+= (uint32_t) (v7x2); + ksum7 += (uint32_t) (v7x3); + ksum7 += (uint32_t) (v7x4); + ksum7 += (uint32_t) (v7x5); + ksum7 += (uint32_t) (v7x6); + ksum7 += (uint32_t) (v7x7); + ksum7 += (uint32_t) (v7x8); + ksum7 += (uint32_t) (v7x9); + ksum7 += (uint32_t) (v7x10); + ksum7 += (uint32_t) (v7x11); + ksum7 += (uint32_t) (v7x12); + ksum7 += (uint32_t) (v7x13); + ksum7 += (uint32_t) (v7x14); + ksum7 += (uint32_t) (v7x15); + // Subtract 16 zero points (8) + ksum7 -= 128; + + out[56] = (v7x0 | (v7x8 << 4)) ^ 0x88; + out[57] = (v7x1 | (v7x9 << 4)) ^ 0x88; + out[58] = (v7x2 | (v7x10 << 4)) ^ 0x88; + out[59] = (v7x3 | (v7x11 << 4)) ^ 0x88; + out[60] = (v7x4 | (v7x12 << 4)) ^ 0x88; + out[61] = (v7x5 | (v7x13 << 4)) ^ 0x88; + out[62] = (v7x6 | (v7x14 << 4)) ^ 0x88; + out[63] = (v7x7 | (v7x15 << 4)) ^ 0x88; + const uint8_t v8x0 = w8[0] & 0xF; + const uint8_t v8x1 = w8[0] >> 4; + const uint8_t v8x2 = w8[1] & 0xF; + const uint8_t v8x3 = w8[1] >> 4; + const uint8_t v8x4 = w8[2] & 0xF; + const uint8_t v8x5 = w8[2] >> 4; + const uint8_t v8x6 = w8[3] & 0xF; + const uint8_t v8x7 = w8[3] >> 4; + const uint8_t v8x8 = w8[4] & 0xF; + const uint8_t v8x9 = w8[4] >> 4; + const uint8_t v8x10 = w8[5] & 0xF; + const uint8_t v8x11 = w8[5] >> 4; + const uint8_t v8x12 = w8[6] & 0xF; + const uint8_t v8x13 = w8[6] >> 4; + const uint8_t v8x14 = w8[7] & 0xF; + const uint8_t v8x15 = w8[7] >> 4; + w8 += 8; + + ksum8 += (uint32_t) (v8x0); + ksum8 += (uint32_t) (v8x1); + ksum8 += (uint32_t) (v8x2); + ksum8 += (uint32_t) (v8x3); + ksum8 += (uint32_t) (v8x4); + ksum8 += (uint32_t) (v8x5); + ksum8 += (uint32_t) (v8x6); + ksum8 += (uint32_t) (v8x7); + ksum8 += (uint32_t) (v8x8); + ksum8 += (uint32_t) (v8x9); + ksum8 += (uint32_t) (v8x10); + ksum8 += (uint32_t) (v8x11); + ksum8 += (uint32_t) (v8x12); + ksum8 += (uint32_t) (v8x13); + ksum8 += (uint32_t) (v8x14); + ksum8 += (uint32_t) (v8x15); + // Subtract 16 zero points (8) + ksum8 -= 128; + + out[64] = (v8x0 | (v8x8 << 4)) ^ 0x88; + out[65] = (v8x1 | (v8x9 << 4)) ^ 0x88; + out[66] = (v8x2 | (v8x10 << 4)) ^ 0x88; + out[67] = (v8x3 | (v8x11 << 4)) ^ 0x88; + out[68] = (v8x4 | (v8x12 << 4)) ^ 0x88; + out[69] = (v8x5 | (v8x13 << 4)) ^ 0x88; + out[70] = (v8x6 | (v8x14 << 4)) ^ 0x88; + out[71] = (v8x7 | (v8x15 << 4)) ^ 0x88; + const uint8_t v9x0 = w9[0] & 0xF; + const uint8_t v9x1 = w9[0] >> 4; + const uint8_t v9x2 = w9[1] & 0xF; + const uint8_t v9x3 = w9[1] >> 4; + const uint8_t v9x4 = w9[2] & 0xF; + const uint8_t v9x5 = w9[2] >> 4; + const uint8_t v9x6 = w9[3] & 0xF; + const uint8_t v9x7 = w9[3] >> 4; + const uint8_t v9x8 = w9[4] & 0xF; + const uint8_t v9x9 = w9[4] >> 4; + const uint8_t v9x10 = w9[5] & 0xF; + const uint8_t v9x11 = w9[5] >> 4; + const uint8_t v9x12 = w9[6] & 0xF; + const uint8_t v9x13 = w9[6] >> 4; + const uint8_t v9x14 = w9[7] & 0xF; + const uint8_t v9x15 = w9[7] >> 4; + w9 += 8; + + ksum9 += (uint32_t) (v9x0); + ksum9 += (uint32_t) (v9x1); + ksum9 += (uint32_t) (v9x2); + ksum9 += (uint32_t) (v9x3); + ksum9 += (uint32_t) (v9x4); + ksum9 += (uint32_t) (v9x5); + ksum9 += (uint32_t) (v9x6); + ksum9 += (uint32_t) (v9x7); + ksum9 += (uint32_t) (v9x8); + ksum9 += (uint32_t) (v9x9); + ksum9 += (uint32_t) (v9x10); + ksum9 += (uint32_t) (v9x11); + ksum9 += (uint32_t) (v9x12); + ksum9 += (uint32_t) (v9x13); + ksum9 += (uint32_t) (v9x14); + ksum9 += (uint32_t) (v9x15); + // Subtract 16 zero points (8) + ksum9 -= 128; + + out[72] = (v9x0 | (v9x8 << 4)) ^ 0x88; + out[73] = (v9x1 | (v9x9 << 4)) ^ 0x88; + out[74] = (v9x2 | (v9x10 << 4)) ^ 0x88; + out[75] = (v9x3 | (v9x11 << 4)) ^ 0x88; + 
out[76] = (v9x4 | (v9x12 << 4)) ^ 0x88; + out[77] = (v9x5 | (v9x13 << 4)) ^ 0x88; + out[78] = (v9x6 | (v9x14 << 4)) ^ 0x88; + out[79] = (v9x7 | (v9x15 << 4)) ^ 0x88; + const uint8_t v10x0 = w10[0] & 0xF; + const uint8_t v10x1 = w10[0] >> 4; + const uint8_t v10x2 = w10[1] & 0xF; + const uint8_t v10x3 = w10[1] >> 4; + const uint8_t v10x4 = w10[2] & 0xF; + const uint8_t v10x5 = w10[2] >> 4; + const uint8_t v10x6 = w10[3] & 0xF; + const uint8_t v10x7 = w10[3] >> 4; + const uint8_t v10x8 = w10[4] & 0xF; + const uint8_t v10x9 = w10[4] >> 4; + const uint8_t v10x10 = w10[5] & 0xF; + const uint8_t v10x11 = w10[5] >> 4; + const uint8_t v10x12 = w10[6] & 0xF; + const uint8_t v10x13 = w10[6] >> 4; + const uint8_t v10x14 = w10[7] & 0xF; + const uint8_t v10x15 = w10[7] >> 4; + w10 += 8; + + ksum10 += (uint32_t) (v10x0); + ksum10 += (uint32_t) (v10x1); + ksum10 += (uint32_t) (v10x2); + ksum10 += (uint32_t) (v10x3); + ksum10 += (uint32_t) (v10x4); + ksum10 += (uint32_t) (v10x5); + ksum10 += (uint32_t) (v10x6); + ksum10 += (uint32_t) (v10x7); + ksum10 += (uint32_t) (v10x8); + ksum10 += (uint32_t) (v10x9); + ksum10 += (uint32_t) (v10x10); + ksum10 += (uint32_t) (v10x11); + ksum10 += (uint32_t) (v10x12); + ksum10 += (uint32_t) (v10x13); + ksum10 += (uint32_t) (v10x14); + ksum10 += (uint32_t) (v10x15); + // Subtract 16 zero points (8) + ksum10 -= 128; + + out[80] = (v10x0 | (v10x8 << 4)) ^ 0x88; + out[81] = (v10x1 | (v10x9 << 4)) ^ 0x88; + out[82] = (v10x2 | (v10x10 << 4)) ^ 0x88; + out[83] = (v10x3 | (v10x11 << 4)) ^ 0x88; + out[84] = (v10x4 | (v10x12 << 4)) ^ 0x88; + out[85] = (v10x5 | (v10x13 << 4)) ^ 0x88; + out[86] = (v10x6 | (v10x14 << 4)) ^ 0x88; + out[87] = (v10x7 | (v10x15 << 4)) ^ 0x88; + const uint8_t v11x0 = w11[0] & 0xF; + const uint8_t v11x1 = w11[0] >> 4; + const uint8_t v11x2 = w11[1] & 0xF; + const uint8_t v11x3 = w11[1] >> 4; + const uint8_t v11x4 = w11[2] & 0xF; + const uint8_t v11x5 = w11[2] >> 4; + const uint8_t v11x6 = w11[3] & 0xF; + const uint8_t v11x7 = w11[3] >> 4; + const uint8_t v11x8 = w11[4] & 0xF; + const uint8_t v11x9 = w11[4] >> 4; + const uint8_t v11x10 = w11[5] & 0xF; + const uint8_t v11x11 = w11[5] >> 4; + const uint8_t v11x12 = w11[6] & 0xF; + const uint8_t v11x13 = w11[6] >> 4; + const uint8_t v11x14 = w11[7] & 0xF; + const uint8_t v11x15 = w11[7] >> 4; + w11 += 8; + + ksum11 += (uint32_t) (v11x0); + ksum11 += (uint32_t) (v11x1); + ksum11 += (uint32_t) (v11x2); + ksum11 += (uint32_t) (v11x3); + ksum11 += (uint32_t) (v11x4); + ksum11 += (uint32_t) (v11x5); + ksum11 += (uint32_t) (v11x6); + ksum11 += (uint32_t) (v11x7); + ksum11 += (uint32_t) (v11x8); + ksum11 += (uint32_t) (v11x9); + ksum11 += (uint32_t) (v11x10); + ksum11 += (uint32_t) (v11x11); + ksum11 += (uint32_t) (v11x12); + ksum11 += (uint32_t) (v11x13); + ksum11 += (uint32_t) (v11x14); + ksum11 += (uint32_t) (v11x15); + // Subtract 16 zero points (8) + ksum11 -= 128; + + out[88] = (v11x0 | (v11x8 << 4)) ^ 0x88; + out[89] = (v11x1 | (v11x9 << 4)) ^ 0x88; + out[90] = (v11x2 | (v11x10 << 4)) ^ 0x88; + out[91] = (v11x3 | (v11x11 << 4)) ^ 0x88; + out[92] = (v11x4 | (v11x12 << 4)) ^ 0x88; + out[93] = (v11x5 | (v11x13 << 4)) ^ 0x88; + out[94] = (v11x6 | (v11x14 << 4)) ^ 0x88; + out[95] = (v11x7 | (v11x15 << 4)) ^ 0x88; + const uint8_t v12x0 = w12[0] & 0xF; + const uint8_t v12x1 = w12[0] >> 4; + const uint8_t v12x2 = w12[1] & 0xF; + const uint8_t v12x3 = w12[1] >> 4; + const uint8_t v12x4 = w12[2] & 0xF; + const uint8_t v12x5 = w12[2] >> 4; + const uint8_t v12x6 = w12[3] & 0xF; + const uint8_t v12x7 = w12[3] >> 4; + const 
uint8_t v12x8 = w12[4] & 0xF; + const uint8_t v12x9 = w12[4] >> 4; + const uint8_t v12x10 = w12[5] & 0xF; + const uint8_t v12x11 = w12[5] >> 4; + const uint8_t v12x12 = w12[6] & 0xF; + const uint8_t v12x13 = w12[6] >> 4; + const uint8_t v12x14 = w12[7] & 0xF; + const uint8_t v12x15 = w12[7] >> 4; + w12 += 8; + + ksum12 += (uint32_t) (v12x0); + ksum12 += (uint32_t) (v12x1); + ksum12 += (uint32_t) (v12x2); + ksum12 += (uint32_t) (v12x3); + ksum12 += (uint32_t) (v12x4); + ksum12 += (uint32_t) (v12x5); + ksum12 += (uint32_t) (v12x6); + ksum12 += (uint32_t) (v12x7); + ksum12 += (uint32_t) (v12x8); + ksum12 += (uint32_t) (v12x9); + ksum12 += (uint32_t) (v12x10); + ksum12 += (uint32_t) (v12x11); + ksum12 += (uint32_t) (v12x12); + ksum12 += (uint32_t) (v12x13); + ksum12 += (uint32_t) (v12x14); + ksum12 += (uint32_t) (v12x15); + // Subtract 16 zero points (8) + ksum12 -= 128; + + out[96] = (v12x0 | (v12x8 << 4)) ^ 0x88; + out[97] = (v12x1 | (v12x9 << 4)) ^ 0x88; + out[98] = (v12x2 | (v12x10 << 4)) ^ 0x88; + out[99] = (v12x3 | (v12x11 << 4)) ^ 0x88; + out[100] = (v12x4 | (v12x12 << 4)) ^ 0x88; + out[101] = (v12x5 | (v12x13 << 4)) ^ 0x88; + out[102] = (v12x6 | (v12x14 << 4)) ^ 0x88; + out[103] = (v12x7 | (v12x15 << 4)) ^ 0x88; + const uint8_t v13x0 = w13[0] & 0xF; + const uint8_t v13x1 = w13[0] >> 4; + const uint8_t v13x2 = w13[1] & 0xF; + const uint8_t v13x3 = w13[1] >> 4; + const uint8_t v13x4 = w13[2] & 0xF; + const uint8_t v13x5 = w13[2] >> 4; + const uint8_t v13x6 = w13[3] & 0xF; + const uint8_t v13x7 = w13[3] >> 4; + const uint8_t v13x8 = w13[4] & 0xF; + const uint8_t v13x9 = w13[4] >> 4; + const uint8_t v13x10 = w13[5] & 0xF; + const uint8_t v13x11 = w13[5] >> 4; + const uint8_t v13x12 = w13[6] & 0xF; + const uint8_t v13x13 = w13[6] >> 4; + const uint8_t v13x14 = w13[7] & 0xF; + const uint8_t v13x15 = w13[7] >> 4; + w13 += 8; + + ksum13 += (uint32_t) (v13x0); + ksum13 += (uint32_t) (v13x1); + ksum13 += (uint32_t) (v13x2); + ksum13 += (uint32_t) (v13x3); + ksum13 += (uint32_t) (v13x4); + ksum13 += (uint32_t) (v13x5); + ksum13 += (uint32_t) (v13x6); + ksum13 += (uint32_t) (v13x7); + ksum13 += (uint32_t) (v13x8); + ksum13 += (uint32_t) (v13x9); + ksum13 += (uint32_t) (v13x10); + ksum13 += (uint32_t) (v13x11); + ksum13 += (uint32_t) (v13x12); + ksum13 += (uint32_t) (v13x13); + ksum13 += (uint32_t) (v13x14); + ksum13 += (uint32_t) (v13x15); + // Subtract 16 zero points (8) + ksum13 -= 128; + + out[104] = (v13x0 | (v13x8 << 4)) ^ 0x88; + out[105] = (v13x1 | (v13x9 << 4)) ^ 0x88; + out[106] = (v13x2 | (v13x10 << 4)) ^ 0x88; + out[107] = (v13x3 | (v13x11 << 4)) ^ 0x88; + out[108] = (v13x4 | (v13x12 << 4)) ^ 0x88; + out[109] = (v13x5 | (v13x13 << 4)) ^ 0x88; + out[110] = (v13x6 | (v13x14 << 4)) ^ 0x88; + out[111] = (v13x7 | (v13x15 << 4)) ^ 0x88; + const uint8_t v14x0 = w14[0] & 0xF; + const uint8_t v14x1 = w14[0] >> 4; + const uint8_t v14x2 = w14[1] & 0xF; + const uint8_t v14x3 = w14[1] >> 4; + const uint8_t v14x4 = w14[2] & 0xF; + const uint8_t v14x5 = w14[2] >> 4; + const uint8_t v14x6 = w14[3] & 0xF; + const uint8_t v14x7 = w14[3] >> 4; + const uint8_t v14x8 = w14[4] & 0xF; + const uint8_t v14x9 = w14[4] >> 4; + const uint8_t v14x10 = w14[5] & 0xF; + const uint8_t v14x11 = w14[5] >> 4; + const uint8_t v14x12 = w14[6] & 0xF; + const uint8_t v14x13 = w14[6] >> 4; + const uint8_t v14x14 = w14[7] & 0xF; + const uint8_t v14x15 = w14[7] >> 4; + w14 += 8; + + ksum14 += (uint32_t) (v14x0); + ksum14 += (uint32_t) (v14x1); + ksum14 += (uint32_t) (v14x2); + ksum14 += (uint32_t) (v14x3); + ksum14 += (uint32_t) 
(v14x4); + ksum14 += (uint32_t) (v14x5); + ksum14 += (uint32_t) (v14x6); + ksum14 += (uint32_t) (v14x7); + ksum14 += (uint32_t) (v14x8); + ksum14 += (uint32_t) (v14x9); + ksum14 += (uint32_t) (v14x10); + ksum14 += (uint32_t) (v14x11); + ksum14 += (uint32_t) (v14x12); + ksum14 += (uint32_t) (v14x13); + ksum14 += (uint32_t) (v14x14); + ksum14 += (uint32_t) (v14x15); + // Subtract 16 zero points (8) + ksum14 -= 128; + + out[112] = (v14x0 | (v14x8 << 4)) ^ 0x88; + out[113] = (v14x1 | (v14x9 << 4)) ^ 0x88; + out[114] = (v14x2 | (v14x10 << 4)) ^ 0x88; + out[115] = (v14x3 | (v14x11 << 4)) ^ 0x88; + out[116] = (v14x4 | (v14x12 << 4)) ^ 0x88; + out[117] = (v14x5 | (v14x13 << 4)) ^ 0x88; + out[118] = (v14x6 | (v14x14 << 4)) ^ 0x88; + out[119] = (v14x7 | (v14x15 << 4)) ^ 0x88; + const uint8_t v15x0 = w15[0] & 0xF; + const uint8_t v15x1 = w15[0] >> 4; + const uint8_t v15x2 = w15[1] & 0xF; + const uint8_t v15x3 = w15[1] >> 4; + const uint8_t v15x4 = w15[2] & 0xF; + const uint8_t v15x5 = w15[2] >> 4; + const uint8_t v15x6 = w15[3] & 0xF; + const uint8_t v15x7 = w15[3] >> 4; + const uint8_t v15x8 = w15[4] & 0xF; + const uint8_t v15x9 = w15[4] >> 4; + const uint8_t v15x10 = w15[5] & 0xF; + const uint8_t v15x11 = w15[5] >> 4; + const uint8_t v15x12 = w15[6] & 0xF; + const uint8_t v15x13 = w15[6] >> 4; + const uint8_t v15x14 = w15[7] & 0xF; + const uint8_t v15x15 = w15[7] >> 4; + w15 += 8; + + ksum15 += (uint32_t) (v15x0); + ksum15 += (uint32_t) (v15x1); + ksum15 += (uint32_t) (v15x2); + ksum15 += (uint32_t) (v15x3); + ksum15 += (uint32_t) (v15x4); + ksum15 += (uint32_t) (v15x5); + ksum15 += (uint32_t) (v15x6); + ksum15 += (uint32_t) (v15x7); + ksum15 += (uint32_t) (v15x8); + ksum15 += (uint32_t) (v15x9); + ksum15 += (uint32_t) (v15x10); + ksum15 += (uint32_t) (v15x11); + ksum15 += (uint32_t) (v15x12); + ksum15 += (uint32_t) (v15x13); + ksum15 += (uint32_t) (v15x14); + ksum15 += (uint32_t) (v15x15); + // Subtract 16 zero points (8) + ksum15 -= 128; + + out[120] = (v15x0 | (v15x8 << 4)) ^ 0x88; + out[121] = (v15x1 | (v15x9 << 4)) ^ 0x88; + out[122] = (v15x2 | (v15x10 << 4)) ^ 0x88; + out[123] = (v15x3 | (v15x11 << 4)) ^ 0x88; + out[124] = (v15x4 | (v15x12 << 4)) ^ 0x88; + out[125] = (v15x5 | (v15x13 << 4)) ^ 0x88; + out[126] = (v15x6 | (v15x14 << 4)) ^ 0x88; + out[127] = (v15x7 | (v15x15 << 4)) ^ 0x88; + + out += 128; + } + float scale0 = math_cvt_fp32_bf16(s0[0]); + float scale1 = math_cvt_fp32_bf16(s1[0]); + float scale2 = math_cvt_fp32_bf16(s2[0]); + float scale3 = math_cvt_fp32_bf16(s3[0]); + float scale4 = math_cvt_fp32_bf16(s4[0]); + float scale5 = math_cvt_fp32_bf16(s5[0]); + float scale6 = math_cvt_fp32_bf16(s6[0]); + float scale7 = math_cvt_fp32_bf16(s7[0]); + float scale8 = math_cvt_fp32_bf16(s8[0]); + float scale9 = math_cvt_fp32_bf16(s9[0]); + float scale10 = math_cvt_fp32_bf16(s10[0]); + float scale11 = math_cvt_fp32_bf16(s11[0]); + float scale12 = math_cvt_fp32_bf16(s12[0]); + float scale13 = math_cvt_fp32_bf16(s13[0]); + float scale14 = math_cvt_fp32_bf16(s14[0]); + float scale15 = math_cvt_fp32_bf16(s15[0]); + s0 += 1; + s1 += 1; + s2 += 1; + s3 += 1; + s4 += 1; + s5 += 1; + s6 += 1; + s7 += 1; + s8 += 1; + s9 += 1; + s10 += 1; + s11 += 1; + s12 += 1; + s13 += 1; + s14 += 1; + s15 += 1; + + + packed_k_scaled_sum[0] -= (float)ksum0 * izp * scale0; + packed_k_scaled_sum[1] -= (float)ksum1 * izp * scale1; + packed_k_scaled_sum[2] -= (float)ksum2 * izp * scale2; + packed_k_scaled_sum[3] -= (float)ksum3 * izp * scale3; + packed_k_scaled_sum[4] -= (float)ksum4 * izp * scale4; + 
packed_k_scaled_sum[5] -= (float)ksum5 * izp * scale5; + packed_k_scaled_sum[6] -= (float)ksum6 * izp * scale6; + packed_k_scaled_sum[7] -= (float)ksum7 * izp * scale7; + packed_k_scaled_sum[8] -= (float)ksum8 * izp * scale8; + packed_k_scaled_sum[9] -= (float)ksum9 * izp * scale9; + packed_k_scaled_sum[10] -= (float)ksum10 * izp * scale10; + packed_k_scaled_sum[11] -= (float)ksum11 * izp * scale11; + packed_k_scaled_sum[12] -= (float)ksum12 * izp * scale12; + packed_k_scaled_sum[13] -= (float)ksum13 * izp * scale13; + packed_k_scaled_sum[14] -= (float)ksum14 * izp * scale14; + packed_k_scaled_sum[15] -= (float)ksum15 * izp * scale15; + + ((uint16_t*) out)[0] = math_cvt_bf16_fp32(scale0 / 16.0f); + ((uint16_t*) out)[1] = math_cvt_bf16_fp32(scale1 / 16.0f); + ((uint16_t*) out)[2] = math_cvt_bf16_fp32(scale2 / 16.0f); + ((uint16_t*) out)[3] = math_cvt_bf16_fp32(scale3 / 16.0f); + ((uint16_t*) out)[4] = math_cvt_bf16_fp32(scale4 / 16.0f); + ((uint16_t*) out)[5] = math_cvt_bf16_fp32(scale5 / 16.0f); + ((uint16_t*) out)[6] = math_cvt_bf16_fp32(scale6 / 16.0f); + ((uint16_t*) out)[7] = math_cvt_bf16_fp32(scale7 / 16.0f); + ((uint16_t*) out)[8] = math_cvt_bf16_fp32(scale8 / 16.0f); + ((uint16_t*) out)[9] = math_cvt_bf16_fp32(scale9 / 16.0f); + ((uint16_t*) out)[10] = math_cvt_bf16_fp32(scale10 / 16.0f); + ((uint16_t*) out)[11] = math_cvt_bf16_fp32(scale11 / 16.0f); + ((uint16_t*) out)[12] = math_cvt_bf16_fp32(scale12 / 16.0f); + ((uint16_t*) out)[13] = math_cvt_bf16_fp32(scale13 / 16.0f); + ((uint16_t*) out)[14] = math_cvt_bf16_fp32(scale14 / 16.0f); + ((uint16_t*) out)[15] = math_cvt_bf16_fp32(scale15 / 16.0f); + + out += 16 * sizeof(uint16_t); + } + + + if XNN_LIKELY(b != NULL){ + ((uint32_t*) out)[0] = b[0]; + ((uint32_t*) out)[1] = b[1]; + ((uint32_t*) out)[2] = b[2]; + ((uint32_t*) out)[3] = b[3]; + ((uint32_t*) out)[4] = b[4]; + ((uint32_t*) out)[5] = b[5]; + ((uint32_t*) out)[6] = b[6]; + ((uint32_t*) out)[7] = b[7]; + ((uint32_t*) out)[8] = b[8]; + ((uint32_t*) out)[9] = b[9]; + ((uint32_t*) out)[10] = b[10]; + ((uint32_t*) out)[11] = b[11]; + ((uint32_t*) out)[12] = b[12]; + ((uint32_t*) out)[13] = b[13]; + ((uint32_t*) out)[14] = b[14]; + ((uint32_t*) out)[15] = b[15]; + b += 16; + } else { + ((uint32_t*) out)[0] = 0; + ((uint32_t*) out)[1] = 0; + ((uint32_t*) out)[2] = 0; + ((uint32_t*) out)[3] = 0; + ((uint32_t*) out)[4] = 0; + ((uint32_t*) out)[5] = 0; + ((uint32_t*) out)[6] = 0; + ((uint32_t*) out)[7] = 0; + ((uint32_t*) out)[8] = 0; + ((uint32_t*) out)[9] = 0; + ((uint32_t*) out)[10] = 0; + ((uint32_t*) out)[11] = 0; + ((uint32_t*) out)[12] = 0; + ((uint32_t*) out)[13] = 0; + ((uint32_t*) out)[14] = 0; + ((uint32_t*) out)[15] = 0; + } + out += 16 * sizeof(uint32_t); + w0 = w15; + s0 = s15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + float* packed_k_scaled_sum = (float*) out; + ((float*) out)[0] = 0; + ((float*) out)[1] = 0; + ((float*) out)[2] = 0; + ((float*) out)[3] = 0; + ((float*) out)[4] = 0; + ((float*) out)[5] = 0; + ((float*) out)[6] = 0; + ((float*) out)[7] = 0; + ((float*) out)[8] = 0; + ((float*) out)[9] = 0; + ((float*) out)[10] = 0; + ((float*) out)[11] = 0; + ((float*) out)[12] = 0; + ((float*) out)[13] = 0; + ((float*) out)[14] = 0; + ((float*) out)[15] = 0; + out += 16 * sizeof(float); + // NR remainder has less than 16 + const uint8_t* w1 = w0 + (kc >> 1); + const uint16_t* s1 = s0 + num_blocks; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + s1 = s0; + } + const uint8_t* w2 = w1 + (kc >> 1); + const uint16_t* s2 = s1 + num_blocks; + if 
XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + s2 = s1; + } + const uint8_t* w3 = w2 + (kc >> 1); + const uint16_t* s3 = s2 + num_blocks; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + s3 = s2; + } + const uint8_t* w4 = w3 + (kc >> 1); + const uint16_t* s4 = s3 + num_blocks; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + s4 = s3; + } + const uint8_t* w5 = w4 + (kc >> 1); + const uint16_t* s5 = s4 + num_blocks; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + s5 = s4; + } + const uint8_t* w6 = w5 + (kc >> 1); + const uint16_t* s6 = s5 + num_blocks; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + s6 = s5; + } + const uint8_t* w7 = w6 + (kc >> 1); + const uint16_t* s7 = s6 + num_blocks; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + s7 = s6; + } + const uint8_t* w8 = w7 + (kc >> 1); + const uint16_t* s8 = s7 + num_blocks; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + s8 = s7; + } + const uint8_t* w9 = w8 + (kc >> 1); + const uint16_t* s9 = s8 + num_blocks; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + s9 = s8; + } + const uint8_t* w10 = w9 + (kc >> 1); + const uint16_t* s10 = s9 + num_blocks; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + s10 = s9; + } + const uint8_t* w11 = w10 + (kc >> 1); + const uint16_t* s11 = s10 + num_blocks; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + s11 = s10; + } + const uint8_t* w12 = w11 + (kc >> 1); + const uint16_t* s12 = s11 + num_blocks; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + s12 = s11; + } + const uint8_t* w13 = w12 + (kc >> 1); + const uint16_t* s13 = s12 + num_blocks; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + s13 = s12; + } + const uint8_t* w14 = w13 + (kc >> 1); + const uint16_t* s14 = s13 + num_blocks; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + s14 = s13; + } + + size_t kb = kc; + // Process k by blocks (bl) + for (; kb >= bl; kb-=bl) { + // Initialize KSum as subtracting bl zero points (8) + int32_t ksum0 = 0; + int32_t ksum1 = 0; + int32_t ksum2 = 0; + int32_t ksum3 = 0; + int32_t ksum4 = 0; + int32_t ksum5 = 0; + int32_t ksum6 = 0; + int32_t ksum7 = 0; + int32_t ksum8 = 0; + int32_t ksum9 = 0; + int32_t ksum10 = 0; + int32_t ksum11 = 0; + int32_t ksum12 = 0; + int32_t ksum13 = 0; + int32_t ksum14 = 0; + size_t k = bl; + for(; k >= 16; k-=16) { + const uint8_t v0x0 = w0[0] & 0xF; + const uint8_t v0x1 = w0[0] >> 4; + const uint8_t v0x2 = w0[1] & 0xF; + const uint8_t v0x3 = w0[1] >> 4; + const uint8_t v0x4 = w0[2] & 0xF; + const uint8_t v0x5 = w0[2] >> 4; + const uint8_t v0x6 = w0[3] & 0xF; + const uint8_t v0x7 = w0[3] >> 4; + const uint8_t v0x8 = w0[4] & 0xF; + const uint8_t v0x9 = w0[4] >> 4; + const uint8_t v0x10 = w0[5] & 0xF; + const uint8_t v0x11 = w0[5] >> 4; + const uint8_t v0x12 = w0[6] & 0xF; + const uint8_t v0x13 = w0[6] >> 4; + const uint8_t v0x14 = w0[7] & 0xF; + const uint8_t v0x15 = w0[7] >> 4; + w0 += 8; + + ksum0 += (uint32_t) (v0x0); + ksum0 += (uint32_t) (v0x1); + ksum0 += (uint32_t) (v0x2); + ksum0 += (uint32_t) (v0x3); + ksum0 += (uint32_t) (v0x4); + ksum0 += (uint32_t) (v0x5); + ksum0 += (uint32_t) (v0x6); + ksum0 += (uint32_t) (v0x7); + ksum0 += (uint32_t) (v0x8); + ksum0 += (uint32_t) (v0x9); + ksum0 += (uint32_t) (v0x10); + ksum0 += (uint32_t) (v0x11); + ksum0 += (uint32_t) (v0x12); + ksum0 += (uint32_t) (v0x13); + ksum0 += (uint32_t) (v0x14); + ksum0 += (uint32_t) (v0x15); + // Subtract 16 zero points (8) + ksum0 -= 128; + + out[0] = (v0x0 | (v0x8 << 4)) ^ 0x88; + out[1] = (v0x1 | (v0x9 << 4)) ^ 0x88; + out[2] = (v0x2 | (v0x10 << 4)) ^ 0x88; + out[3] = (v0x3 | (v0x11 << 4)) ^ 0x88; + out[4] = (v0x4 | (v0x12 << 4)) ^ 
0x88; + out[5] = (v0x5 | (v0x13 << 4)) ^ 0x88; + out[6] = (v0x6 | (v0x14 << 4)) ^ 0x88; + out[7] = (v0x7 | (v0x15 << 4)) ^ 0x88; + const uint8_t v1x0 = w1[0] & 0xF; + const uint8_t v1x1 = w1[0] >> 4; + const uint8_t v1x2 = w1[1] & 0xF; + const uint8_t v1x3 = w1[1] >> 4; + const uint8_t v1x4 = w1[2] & 0xF; + const uint8_t v1x5 = w1[2] >> 4; + const uint8_t v1x6 = w1[3] & 0xF; + const uint8_t v1x7 = w1[3] >> 4; + const uint8_t v1x8 = w1[4] & 0xF; + const uint8_t v1x9 = w1[4] >> 4; + const uint8_t v1x10 = w1[5] & 0xF; + const uint8_t v1x11 = w1[5] >> 4; + const uint8_t v1x12 = w1[6] & 0xF; + const uint8_t v1x13 = w1[6] >> 4; + const uint8_t v1x14 = w1[7] & 0xF; + const uint8_t v1x15 = w1[7] >> 4; + w1 += 8; + + ksum1 += (uint32_t) (v1x0); + ksum1 += (uint32_t) (v1x1); + ksum1 += (uint32_t) (v1x2); + ksum1 += (uint32_t) (v1x3); + ksum1 += (uint32_t) (v1x4); + ksum1 += (uint32_t) (v1x5); + ksum1 += (uint32_t) (v1x6); + ksum1 += (uint32_t) (v1x7); + ksum1 += (uint32_t) (v1x8); + ksum1 += (uint32_t) (v1x9); + ksum1 += (uint32_t) (v1x10); + ksum1 += (uint32_t) (v1x11); + ksum1 += (uint32_t) (v1x12); + ksum1 += (uint32_t) (v1x13); + ksum1 += (uint32_t) (v1x14); + ksum1 += (uint32_t) (v1x15); + // Subtract 16 zero points (8) + ksum1 -= 128; + + out[8] = (v1x0 | (v1x8 << 4)) ^ 0x88; + out[9] = (v1x1 | (v1x9 << 4)) ^ 0x88; + out[10] = (v1x2 | (v1x10 << 4)) ^ 0x88; + out[11] = (v1x3 | (v1x11 << 4)) ^ 0x88; + out[12] = (v1x4 | (v1x12 << 4)) ^ 0x88; + out[13] = (v1x5 | (v1x13 << 4)) ^ 0x88; + out[14] = (v1x6 | (v1x14 << 4)) ^ 0x88; + out[15] = (v1x7 | (v1x15 << 4)) ^ 0x88; + const uint8_t v2x0 = w2[0] & 0xF; + const uint8_t v2x1 = w2[0] >> 4; + const uint8_t v2x2 = w2[1] & 0xF; + const uint8_t v2x3 = w2[1] >> 4; + const uint8_t v2x4 = w2[2] & 0xF; + const uint8_t v2x5 = w2[2] >> 4; + const uint8_t v2x6 = w2[3] & 0xF; + const uint8_t v2x7 = w2[3] >> 4; + const uint8_t v2x8 = w2[4] & 0xF; + const uint8_t v2x9 = w2[4] >> 4; + const uint8_t v2x10 = w2[5] & 0xF; + const uint8_t v2x11 = w2[5] >> 4; + const uint8_t v2x12 = w2[6] & 0xF; + const uint8_t v2x13 = w2[6] >> 4; + const uint8_t v2x14 = w2[7] & 0xF; + const uint8_t v2x15 = w2[7] >> 4; + w2 += 8; + + ksum2 += (uint32_t) (v2x0); + ksum2 += (uint32_t) (v2x1); + ksum2 += (uint32_t) (v2x2); + ksum2 += (uint32_t) (v2x3); + ksum2 += (uint32_t) (v2x4); + ksum2 += (uint32_t) (v2x5); + ksum2 += (uint32_t) (v2x6); + ksum2 += (uint32_t) (v2x7); + ksum2 += (uint32_t) (v2x8); + ksum2 += (uint32_t) (v2x9); + ksum2 += (uint32_t) (v2x10); + ksum2 += (uint32_t) (v2x11); + ksum2 += (uint32_t) (v2x12); + ksum2 += (uint32_t) (v2x13); + ksum2 += (uint32_t) (v2x14); + ksum2 += (uint32_t) (v2x15); + // Subtract 16 zero points (8) + ksum2 -= 128; + + out[16] = (v2x0 | (v2x8 << 4)) ^ 0x88; + out[17] = (v2x1 | (v2x9 << 4)) ^ 0x88; + out[18] = (v2x2 | (v2x10 << 4)) ^ 0x88; + out[19] = (v2x3 | (v2x11 << 4)) ^ 0x88; + out[20] = (v2x4 | (v2x12 << 4)) ^ 0x88; + out[21] = (v2x5 | (v2x13 << 4)) ^ 0x88; + out[22] = (v2x6 | (v2x14 << 4)) ^ 0x88; + out[23] = (v2x7 | (v2x15 << 4)) ^ 0x88; + const uint8_t v3x0 = w3[0] & 0xF; + const uint8_t v3x1 = w3[0] >> 4; + const uint8_t v3x2 = w3[1] & 0xF; + const uint8_t v3x3 = w3[1] >> 4; + const uint8_t v3x4 = w3[2] & 0xF; + const uint8_t v3x5 = w3[2] >> 4; + const uint8_t v3x6 = w3[3] & 0xF; + const uint8_t v3x7 = w3[3] >> 4; + const uint8_t v3x8 = w3[4] & 0xF; + const uint8_t v3x9 = w3[4] >> 4; + const uint8_t v3x10 = w3[5] & 0xF; + const uint8_t v3x11 = w3[5] >> 4; + const uint8_t v3x12 = w3[6] & 0xF; + const uint8_t v3x13 = w3[6] >> 4; + const 
uint8_t v3x14 = w3[7] & 0xF; + const uint8_t v3x15 = w3[7] >> 4; + w3 += 8; + + ksum3 += (uint32_t) (v3x0); + ksum3 += (uint32_t) (v3x1); + ksum3 += (uint32_t) (v3x2); + ksum3 += (uint32_t) (v3x3); + ksum3 += (uint32_t) (v3x4); + ksum3 += (uint32_t) (v3x5); + ksum3 += (uint32_t) (v3x6); + ksum3 += (uint32_t) (v3x7); + ksum3 += (uint32_t) (v3x8); + ksum3 += (uint32_t) (v3x9); + ksum3 += (uint32_t) (v3x10); + ksum3 += (uint32_t) (v3x11); + ksum3 += (uint32_t) (v3x12); + ksum3 += (uint32_t) (v3x13); + ksum3 += (uint32_t) (v3x14); + ksum3 += (uint32_t) (v3x15); + // Subtract 16 zero points (8) + ksum3 -= 128; + + out[24] = (v3x0 | (v3x8 << 4)) ^ 0x88; + out[25] = (v3x1 | (v3x9 << 4)) ^ 0x88; + out[26] = (v3x2 | (v3x10 << 4)) ^ 0x88; + out[27] = (v3x3 | (v3x11 << 4)) ^ 0x88; + out[28] = (v3x4 | (v3x12 << 4)) ^ 0x88; + out[29] = (v3x5 | (v3x13 << 4)) ^ 0x88; + out[30] = (v3x6 | (v3x14 << 4)) ^ 0x88; + out[31] = (v3x7 | (v3x15 << 4)) ^ 0x88; + const uint8_t v4x0 = w4[0] & 0xF; + const uint8_t v4x1 = w4[0] >> 4; + const uint8_t v4x2 = w4[1] & 0xF; + const uint8_t v4x3 = w4[1] >> 4; + const uint8_t v4x4 = w4[2] & 0xF; + const uint8_t v4x5 = w4[2] >> 4; + const uint8_t v4x6 = w4[3] & 0xF; + const uint8_t v4x7 = w4[3] >> 4; + const uint8_t v4x8 = w4[4] & 0xF; + const uint8_t v4x9 = w4[4] >> 4; + const uint8_t v4x10 = w4[5] & 0xF; + const uint8_t v4x11 = w4[5] >> 4; + const uint8_t v4x12 = w4[6] & 0xF; + const uint8_t v4x13 = w4[6] >> 4; + const uint8_t v4x14 = w4[7] & 0xF; + const uint8_t v4x15 = w4[7] >> 4; + w4 += 8; + + ksum4 += (uint32_t) (v4x0); + ksum4 += (uint32_t) (v4x1); + ksum4 += (uint32_t) (v4x2); + ksum4 += (uint32_t) (v4x3); + ksum4 += (uint32_t) (v4x4); + ksum4 += (uint32_t) (v4x5); + ksum4 += (uint32_t) (v4x6); + ksum4 += (uint32_t) (v4x7); + ksum4 += (uint32_t) (v4x8); + ksum4 += (uint32_t) (v4x9); + ksum4 += (uint32_t) (v4x10); + ksum4 += (uint32_t) (v4x11); + ksum4 += (uint32_t) (v4x12); + ksum4 += (uint32_t) (v4x13); + ksum4 += (uint32_t) (v4x14); + ksum4 += (uint32_t) (v4x15); + // Subtract 16 zero points (8) + ksum4 -= 128; + + out[32] = (v4x0 | (v4x8 << 4)) ^ 0x88; + out[33] = (v4x1 | (v4x9 << 4)) ^ 0x88; + out[34] = (v4x2 | (v4x10 << 4)) ^ 0x88; + out[35] = (v4x3 | (v4x11 << 4)) ^ 0x88; + out[36] = (v4x4 | (v4x12 << 4)) ^ 0x88; + out[37] = (v4x5 | (v4x13 << 4)) ^ 0x88; + out[38] = (v4x6 | (v4x14 << 4)) ^ 0x88; + out[39] = (v4x7 | (v4x15 << 4)) ^ 0x88; + const uint8_t v5x0 = w5[0] & 0xF; + const uint8_t v5x1 = w5[0] >> 4; + const uint8_t v5x2 = w5[1] & 0xF; + const uint8_t v5x3 = w5[1] >> 4; + const uint8_t v5x4 = w5[2] & 0xF; + const uint8_t v5x5 = w5[2] >> 4; + const uint8_t v5x6 = w5[3] & 0xF; + const uint8_t v5x7 = w5[3] >> 4; + const uint8_t v5x8 = w5[4] & 0xF; + const uint8_t v5x9 = w5[4] >> 4; + const uint8_t v5x10 = w5[5] & 0xF; + const uint8_t v5x11 = w5[5] >> 4; + const uint8_t v5x12 = w5[6] & 0xF; + const uint8_t v5x13 = w5[6] >> 4; + const uint8_t v5x14 = w5[7] & 0xF; + const uint8_t v5x15 = w5[7] >> 4; + w5 += 8; + + ksum5 += (uint32_t) (v5x0); + ksum5 += (uint32_t) (v5x1); + ksum5 += (uint32_t) (v5x2); + ksum5 += (uint32_t) (v5x3); + ksum5 += (uint32_t) (v5x4); + ksum5 += (uint32_t) (v5x5); + ksum5 += (uint32_t) (v5x6); + ksum5 += (uint32_t) (v5x7); + ksum5 += (uint32_t) (v5x8); + ksum5 += (uint32_t) (v5x9); + ksum5 += (uint32_t) (v5x10); + ksum5 += (uint32_t) (v5x11); + ksum5 += (uint32_t) (v5x12); + ksum5 += (uint32_t) (v5x13); + ksum5 += (uint32_t) (v5x14); + ksum5 += (uint32_t) (v5x15); + // Subtract 16 zero points (8) + ksum5 -= 128; + + out[40] = (v5x0 | 
(v5x8 << 4)) ^ 0x88; + out[41] = (v5x1 | (v5x9 << 4)) ^ 0x88; + out[42] = (v5x2 | (v5x10 << 4)) ^ 0x88; + out[43] = (v5x3 | (v5x11 << 4)) ^ 0x88; + out[44] = (v5x4 | (v5x12 << 4)) ^ 0x88; + out[45] = (v5x5 | (v5x13 << 4)) ^ 0x88; + out[46] = (v5x6 | (v5x14 << 4)) ^ 0x88; + out[47] = (v5x7 | (v5x15 << 4)) ^ 0x88; + const uint8_t v6x0 = w6[0] & 0xF; + const uint8_t v6x1 = w6[0] >> 4; + const uint8_t v6x2 = w6[1] & 0xF; + const uint8_t v6x3 = w6[1] >> 4; + const uint8_t v6x4 = w6[2] & 0xF; + const uint8_t v6x5 = w6[2] >> 4; + const uint8_t v6x6 = w6[3] & 0xF; + const uint8_t v6x7 = w6[3] >> 4; + const uint8_t v6x8 = w6[4] & 0xF; + const uint8_t v6x9 = w6[4] >> 4; + const uint8_t v6x10 = w6[5] & 0xF; + const uint8_t v6x11 = w6[5] >> 4; + const uint8_t v6x12 = w6[6] & 0xF; + const uint8_t v6x13 = w6[6] >> 4; + const uint8_t v6x14 = w6[7] & 0xF; + const uint8_t v6x15 = w6[7] >> 4; + w6 += 8; + + ksum6 += (uint32_t) (v6x0); + ksum6 += (uint32_t) (v6x1); + ksum6 += (uint32_t) (v6x2); + ksum6 += (uint32_t) (v6x3); + ksum6 += (uint32_t) (v6x4); + ksum6 += (uint32_t) (v6x5); + ksum6 += (uint32_t) (v6x6); + ksum6 += (uint32_t) (v6x7); + ksum6 += (uint32_t) (v6x8); + ksum6 += (uint32_t) (v6x9); + ksum6 += (uint32_t) (v6x10); + ksum6 += (uint32_t) (v6x11); + ksum6 += (uint32_t) (v6x12); + ksum6 += (uint32_t) (v6x13); + ksum6 += (uint32_t) (v6x14); + ksum6 += (uint32_t) (v6x15); + // Subtract 16 zero points (8) + ksum6 -= 128; + + out[48] = (v6x0 | (v6x8 << 4)) ^ 0x88; + out[49] = (v6x1 | (v6x9 << 4)) ^ 0x88; + out[50] = (v6x2 | (v6x10 << 4)) ^ 0x88; + out[51] = (v6x3 | (v6x11 << 4)) ^ 0x88; + out[52] = (v6x4 | (v6x12 << 4)) ^ 0x88; + out[53] = (v6x5 | (v6x13 << 4)) ^ 0x88; + out[54] = (v6x6 | (v6x14 << 4)) ^ 0x88; + out[55] = (v6x7 | (v6x15 << 4)) ^ 0x88; + const uint8_t v7x0 = w7[0] & 0xF; + const uint8_t v7x1 = w7[0] >> 4; + const uint8_t v7x2 = w7[1] & 0xF; + const uint8_t v7x3 = w7[1] >> 4; + const uint8_t v7x4 = w7[2] & 0xF; + const uint8_t v7x5 = w7[2] >> 4; + const uint8_t v7x6 = w7[3] & 0xF; + const uint8_t v7x7 = w7[3] >> 4; + const uint8_t v7x8 = w7[4] & 0xF; + const uint8_t v7x9 = w7[4] >> 4; + const uint8_t v7x10 = w7[5] & 0xF; + const uint8_t v7x11 = w7[5] >> 4; + const uint8_t v7x12 = w7[6] & 0xF; + const uint8_t v7x13 = w7[6] >> 4; + const uint8_t v7x14 = w7[7] & 0xF; + const uint8_t v7x15 = w7[7] >> 4; + w7 += 8; + + ksum7 += (uint32_t) (v7x0); + ksum7 += (uint32_t) (v7x1); + ksum7 += (uint32_t) (v7x2); + ksum7 += (uint32_t) (v7x3); + ksum7 += (uint32_t) (v7x4); + ksum7 += (uint32_t) (v7x5); + ksum7 += (uint32_t) (v7x6); + ksum7 += (uint32_t) (v7x7); + ksum7 += (uint32_t) (v7x8); + ksum7 += (uint32_t) (v7x9); + ksum7 += (uint32_t) (v7x10); + ksum7 += (uint32_t) (v7x11); + ksum7 += (uint32_t) (v7x12); + ksum7 += (uint32_t) (v7x13); + ksum7 += (uint32_t) (v7x14); + ksum7 += (uint32_t) (v7x15); + // Subtract 16 zero points (8) + ksum7 -= 128; + + out[56] = (v7x0 | (v7x8 << 4)) ^ 0x88; + out[57] = (v7x1 | (v7x9 << 4)) ^ 0x88; + out[58] = (v7x2 | (v7x10 << 4)) ^ 0x88; + out[59] = (v7x3 | (v7x11 << 4)) ^ 0x88; + out[60] = (v7x4 | (v7x12 << 4)) ^ 0x88; + out[61] = (v7x5 | (v7x13 << 4)) ^ 0x88; + out[62] = (v7x6 | (v7x14 << 4)) ^ 0x88; + out[63] = (v7x7 | (v7x15 << 4)) ^ 0x88; + const uint8_t v8x0 = w8[0] & 0xF; + const uint8_t v8x1 = w8[0] >> 4; + const uint8_t v8x2 = w8[1] & 0xF; + const uint8_t v8x3 = w8[1] >> 4; + const uint8_t v8x4 = w8[2] & 0xF; + const uint8_t v8x5 = w8[2] >> 4; + const uint8_t v8x6 = w8[3] & 0xF; + const uint8_t v8x7 = w8[3] >> 4; + const uint8_t v8x8 = w8[4] & 0xF; + 
const uint8_t v8x9 = w8[4] >> 4; + const uint8_t v8x10 = w8[5] & 0xF; + const uint8_t v8x11 = w8[5] >> 4; + const uint8_t v8x12 = w8[6] & 0xF; + const uint8_t v8x13 = w8[6] >> 4; + const uint8_t v8x14 = w8[7] & 0xF; + const uint8_t v8x15 = w8[7] >> 4; + w8 += 8; + + ksum8 += (uint32_t) (v8x0); + ksum8 += (uint32_t) (v8x1); + ksum8 += (uint32_t) (v8x2); + ksum8 += (uint32_t) (v8x3); + ksum8 += (uint32_t) (v8x4); + ksum8 += (uint32_t) (v8x5); + ksum8 += (uint32_t) (v8x6); + ksum8 += (uint32_t) (v8x7); + ksum8 += (uint32_t) (v8x8); + ksum8 += (uint32_t) (v8x9); + ksum8 += (uint32_t) (v8x10); + ksum8 += (uint32_t) (v8x11); + ksum8 += (uint32_t) (v8x12); + ksum8 += (uint32_t) (v8x13); + ksum8 += (uint32_t) (v8x14); + ksum8 += (uint32_t) (v8x15); + // Subtract 16 zero points (8) + ksum8 -= 128; + + out[64] = (v8x0 | (v8x8 << 4)) ^ 0x88; + out[65] = (v8x1 | (v8x9 << 4)) ^ 0x88; + out[66] = (v8x2 | (v8x10 << 4)) ^ 0x88; + out[67] = (v8x3 | (v8x11 << 4)) ^ 0x88; + out[68] = (v8x4 | (v8x12 << 4)) ^ 0x88; + out[69] = (v8x5 | (v8x13 << 4)) ^ 0x88; + out[70] = (v8x6 | (v8x14 << 4)) ^ 0x88; + out[71] = (v8x7 | (v8x15 << 4)) ^ 0x88; + const uint8_t v9x0 = w9[0] & 0xF; + const uint8_t v9x1 = w9[0] >> 4; + const uint8_t v9x2 = w9[1] & 0xF; + const uint8_t v9x3 = w9[1] >> 4; + const uint8_t v9x4 = w9[2] & 0xF; + const uint8_t v9x5 = w9[2] >> 4; + const uint8_t v9x6 = w9[3] & 0xF; + const uint8_t v9x7 = w9[3] >> 4; + const uint8_t v9x8 = w9[4] & 0xF; + const uint8_t v9x9 = w9[4] >> 4; + const uint8_t v9x10 = w9[5] & 0xF; + const uint8_t v9x11 = w9[5] >> 4; + const uint8_t v9x12 = w9[6] & 0xF; + const uint8_t v9x13 = w9[6] >> 4; + const uint8_t v9x14 = w9[7] & 0xF; + const uint8_t v9x15 = w9[7] >> 4; + w9 += 8; + + ksum9 += (uint32_t) (v9x0); + ksum9 += (uint32_t) (v9x1); + ksum9 += (uint32_t) (v9x2); + ksum9 += (uint32_t) (v9x3); + ksum9 += (uint32_t) (v9x4); + ksum9 += (uint32_t) (v9x5); + ksum9 += (uint32_t) (v9x6); + ksum9 += (uint32_t) (v9x7); + ksum9 += (uint32_t) (v9x8); + ksum9 += (uint32_t) (v9x9); + ksum9 += (uint32_t) (v9x10); + ksum9 += (uint32_t) (v9x11); + ksum9 += (uint32_t) (v9x12); + ksum9 += (uint32_t) (v9x13); + ksum9 += (uint32_t) (v9x14); + ksum9 += (uint32_t) (v9x15); + // Subtract 16 zero points (8) + ksum9 -= 128; + + out[72] = (v9x0 | (v9x8 << 4)) ^ 0x88; + out[73] = (v9x1 | (v9x9 << 4)) ^ 0x88; + out[74] = (v9x2 | (v9x10 << 4)) ^ 0x88; + out[75] = (v9x3 | (v9x11 << 4)) ^ 0x88; + out[76] = (v9x4 | (v9x12 << 4)) ^ 0x88; + out[77] = (v9x5 | (v9x13 << 4)) ^ 0x88; + out[78] = (v9x6 | (v9x14 << 4)) ^ 0x88; + out[79] = (v9x7 | (v9x15 << 4)) ^ 0x88; + const uint8_t v10x0 = w10[0] & 0xF; + const uint8_t v10x1 = w10[0] >> 4; + const uint8_t v10x2 = w10[1] & 0xF; + const uint8_t v10x3 = w10[1] >> 4; + const uint8_t v10x4 = w10[2] & 0xF; + const uint8_t v10x5 = w10[2] >> 4; + const uint8_t v10x6 = w10[3] & 0xF; + const uint8_t v10x7 = w10[3] >> 4; + const uint8_t v10x8 = w10[4] & 0xF; + const uint8_t v10x9 = w10[4] >> 4; + const uint8_t v10x10 = w10[5] & 0xF; + const uint8_t v10x11 = w10[5] >> 4; + const uint8_t v10x12 = w10[6] & 0xF; + const uint8_t v10x13 = w10[6] >> 4; + const uint8_t v10x14 = w10[7] & 0xF; + const uint8_t v10x15 = w10[7] >> 4; + w10 += 8; + + ksum10 += (uint32_t) (v10x0); + ksum10 += (uint32_t) (v10x1); + ksum10 += (uint32_t) (v10x2); + ksum10 += (uint32_t) (v10x3); + ksum10 += (uint32_t) (v10x4); + ksum10 += (uint32_t) (v10x5); + ksum10 += (uint32_t) (v10x6); + ksum10 += (uint32_t) (v10x7); + ksum10 += (uint32_t) (v10x8); + ksum10 += (uint32_t) (v10x9); + ksum10 += 
(uint32_t) (v10x10); + ksum10 += (uint32_t) (v10x11); + ksum10 += (uint32_t) (v10x12); + ksum10 += (uint32_t) (v10x13); + ksum10 += (uint32_t) (v10x14); + ksum10 += (uint32_t) (v10x15); + // Subtract 16 zero points (8) + ksum10 -= 128; + + out[80] = (v10x0 | (v10x8 << 4)) ^ 0x88; + out[81] = (v10x1 | (v10x9 << 4)) ^ 0x88; + out[82] = (v10x2 | (v10x10 << 4)) ^ 0x88; + out[83] = (v10x3 | (v10x11 << 4)) ^ 0x88; + out[84] = (v10x4 | (v10x12 << 4)) ^ 0x88; + out[85] = (v10x5 | (v10x13 << 4)) ^ 0x88; + out[86] = (v10x6 | (v10x14 << 4)) ^ 0x88; + out[87] = (v10x7 | (v10x15 << 4)) ^ 0x88; + const uint8_t v11x0 = w11[0] & 0xF; + const uint8_t v11x1 = w11[0] >> 4; + const uint8_t v11x2 = w11[1] & 0xF; + const uint8_t v11x3 = w11[1] >> 4; + const uint8_t v11x4 = w11[2] & 0xF; + const uint8_t v11x5 = w11[2] >> 4; + const uint8_t v11x6 = w11[3] & 0xF; + const uint8_t v11x7 = w11[3] >> 4; + const uint8_t v11x8 = w11[4] & 0xF; + const uint8_t v11x9 = w11[4] >> 4; + const uint8_t v11x10 = w11[5] & 0xF; + const uint8_t v11x11 = w11[5] >> 4; + const uint8_t v11x12 = w11[6] & 0xF; + const uint8_t v11x13 = w11[6] >> 4; + const uint8_t v11x14 = w11[7] & 0xF; + const uint8_t v11x15 = w11[7] >> 4; + w11 += 8; + + ksum11 += (uint32_t) (v11x0); + ksum11 += (uint32_t) (v11x1); + ksum11 += (uint32_t) (v11x2); + ksum11 += (uint32_t) (v11x3); + ksum11 += (uint32_t) (v11x4); + ksum11 += (uint32_t) (v11x5); + ksum11 += (uint32_t) (v11x6); + ksum11 += (uint32_t) (v11x7); + ksum11 += (uint32_t) (v11x8); + ksum11 += (uint32_t) (v11x9); + ksum11 += (uint32_t) (v11x10); + ksum11 += (uint32_t) (v11x11); + ksum11 += (uint32_t) (v11x12); + ksum11 += (uint32_t) (v11x13); + ksum11 += (uint32_t) (v11x14); + ksum11 += (uint32_t) (v11x15); + // Subtract 16 zero points (8) + ksum11 -= 128; + + out[88] = (v11x0 | (v11x8 << 4)) ^ 0x88; + out[89] = (v11x1 | (v11x9 << 4)) ^ 0x88; + out[90] = (v11x2 | (v11x10 << 4)) ^ 0x88; + out[91] = (v11x3 | (v11x11 << 4)) ^ 0x88; + out[92] = (v11x4 | (v11x12 << 4)) ^ 0x88; + out[93] = (v11x5 | (v11x13 << 4)) ^ 0x88; + out[94] = (v11x6 | (v11x14 << 4)) ^ 0x88; + out[95] = (v11x7 | (v11x15 << 4)) ^ 0x88; + const uint8_t v12x0 = w12[0] & 0xF; + const uint8_t v12x1 = w12[0] >> 4; + const uint8_t v12x2 = w12[1] & 0xF; + const uint8_t v12x3 = w12[1] >> 4; + const uint8_t v12x4 = w12[2] & 0xF; + const uint8_t v12x5 = w12[2] >> 4; + const uint8_t v12x6 = w12[3] & 0xF; + const uint8_t v12x7 = w12[3] >> 4; + const uint8_t v12x8 = w12[4] & 0xF; + const uint8_t v12x9 = w12[4] >> 4; + const uint8_t v12x10 = w12[5] & 0xF; + const uint8_t v12x11 = w12[5] >> 4; + const uint8_t v12x12 = w12[6] & 0xF; + const uint8_t v12x13 = w12[6] >> 4; + const uint8_t v12x14 = w12[7] & 0xF; + const uint8_t v12x15 = w12[7] >> 4; + w12 += 8; + + ksum12 += (uint32_t) (v12x0); + ksum12 += (uint32_t) (v12x1); + ksum12 += (uint32_t) (v12x2); + ksum12 += (uint32_t) (v12x3); + ksum12 += (uint32_t) (v12x4); + ksum12 += (uint32_t) (v12x5); + ksum12 += (uint32_t) (v12x6); + ksum12 += (uint32_t) (v12x7); + ksum12 += (uint32_t) (v12x8); + ksum12 += (uint32_t) (v12x9); + ksum12 += (uint32_t) (v12x10); + ksum12 += (uint32_t) (v12x11); + ksum12 += (uint32_t) (v12x12); + ksum12 += (uint32_t) (v12x13); + ksum12 += (uint32_t) (v12x14); + ksum12 += (uint32_t) (v12x15); + // Subtract 16 zero points (8) + ksum12 -= 128; + + out[96] = (v12x0 | (v12x8 << 4)) ^ 0x88; + out[97] = (v12x1 | (v12x9 << 4)) ^ 0x88; + out[98] = (v12x2 | (v12x10 << 4)) ^ 0x88; + out[99] = (v12x3 | (v12x11 << 4)) ^ 0x88; + out[100] = (v12x4 | (v12x12 << 4)) ^ 0x88; + out[101] = 
(v12x5 | (v12x13 << 4)) ^ 0x88; + out[102] = (v12x6 | (v12x14 << 4)) ^ 0x88; + out[103] = (v12x7 | (v12x15 << 4)) ^ 0x88; + const uint8_t v13x0 = w13[0] & 0xF; + const uint8_t v13x1 = w13[0] >> 4; + const uint8_t v13x2 = w13[1] & 0xF; + const uint8_t v13x3 = w13[1] >> 4; + const uint8_t v13x4 = w13[2] & 0xF; + const uint8_t v13x5 = w13[2] >> 4; + const uint8_t v13x6 = w13[3] & 0xF; + const uint8_t v13x7 = w13[3] >> 4; + const uint8_t v13x8 = w13[4] & 0xF; + const uint8_t v13x9 = w13[4] >> 4; + const uint8_t v13x10 = w13[5] & 0xF; + const uint8_t v13x11 = w13[5] >> 4; + const uint8_t v13x12 = w13[6] & 0xF; + const uint8_t v13x13 = w13[6] >> 4; + const uint8_t v13x14 = w13[7] & 0xF; + const uint8_t v13x15 = w13[7] >> 4; + w13 += 8; + + ksum13 += (uint32_t) (v13x0); + ksum13 += (uint32_t) (v13x1); + ksum13 += (uint32_t) (v13x2); + ksum13 += (uint32_t) (v13x3); + ksum13 += (uint32_t) (v13x4); + ksum13 += (uint32_t) (v13x5); + ksum13 += (uint32_t) (v13x6); + ksum13 += (uint32_t) (v13x7); + ksum13 += (uint32_t) (v13x8); + ksum13 += (uint32_t) (v13x9); + ksum13 += (uint32_t) (v13x10); + ksum13 += (uint32_t) (v13x11); + ksum13 += (uint32_t) (v13x12); + ksum13 += (uint32_t) (v13x13); + ksum13 += (uint32_t) (v13x14); + ksum13 += (uint32_t) (v13x15); + // Subtract 16 zero points (8) + ksum13 -= 128; + + out[104] = (v13x0 | (v13x8 << 4)) ^ 0x88; + out[105] = (v13x1 | (v13x9 << 4)) ^ 0x88; + out[106] = (v13x2 | (v13x10 << 4)) ^ 0x88; + out[107] = (v13x3 | (v13x11 << 4)) ^ 0x88; + out[108] = (v13x4 | (v13x12 << 4)) ^ 0x88; + out[109] = (v13x5 | (v13x13 << 4)) ^ 0x88; + out[110] = (v13x6 | (v13x14 << 4)) ^ 0x88; + out[111] = (v13x7 | (v13x15 << 4)) ^ 0x88; + const uint8_t v14x0 = w14[0] & 0xF; + const uint8_t v14x1 = w14[0] >> 4; + const uint8_t v14x2 = w14[1] & 0xF; + const uint8_t v14x3 = w14[1] >> 4; + const uint8_t v14x4 = w14[2] & 0xF; + const uint8_t v14x5 = w14[2] >> 4; + const uint8_t v14x6 = w14[3] & 0xF; + const uint8_t v14x7 = w14[3] >> 4; + const uint8_t v14x8 = w14[4] & 0xF; + const uint8_t v14x9 = w14[4] >> 4; + const uint8_t v14x10 = w14[5] & 0xF; + const uint8_t v14x11 = w14[5] >> 4; + const uint8_t v14x12 = w14[6] & 0xF; + const uint8_t v14x13 = w14[6] >> 4; + const uint8_t v14x14 = w14[7] & 0xF; + const uint8_t v14x15 = w14[7] >> 4; + w14 += 8; + + ksum14 += (uint32_t) (v14x0); + ksum14 += (uint32_t) (v14x1); + ksum14 += (uint32_t) (v14x2); + ksum14 += (uint32_t) (v14x3); + ksum14 += (uint32_t) (v14x4); + ksum14 += (uint32_t) (v14x5); + ksum14 += (uint32_t) (v14x6); + ksum14 += (uint32_t) (v14x7); + ksum14 += (uint32_t) (v14x8); + ksum14 += (uint32_t) (v14x9); + ksum14 += (uint32_t) (v14x10); + ksum14 += (uint32_t) (v14x11); + ksum14 += (uint32_t) (v14x12); + ksum14 += (uint32_t) (v14x13); + ksum14 += (uint32_t) (v14x14); + ksum14 += (uint32_t) (v14x15); + // Subtract 16 zero points (8) + ksum14 -= 128; + + out[112] = (v14x0 | (v14x8 << 4)) ^ 0x88; + out[113] = (v14x1 | (v14x9 << 4)) ^ 0x88; + out[114] = (v14x2 | (v14x10 << 4)) ^ 0x88; + out[115] = (v14x3 | (v14x11 << 4)) ^ 0x88; + out[116] = (v14x4 | (v14x12 << 4)) ^ 0x88; + out[117] = (v14x5 | (v14x13 << 4)) ^ 0x88; + out[118] = (v14x6 | (v14x14 << 4)) ^ 0x88; + out[119] = (v14x7 | (v14x15 << 4)) ^ 0x88; + + out += 128; + } + float scale0 = math_cvt_fp32_bf16(s0[0]); + float scale1 = math_cvt_fp32_bf16(s1[0]); + float scale2 = math_cvt_fp32_bf16(s2[0]); + float scale3 = math_cvt_fp32_bf16(s3[0]); + float scale4 = math_cvt_fp32_bf16(s4[0]); + float scale5 = math_cvt_fp32_bf16(s5[0]); + float scale6 = math_cvt_fp32_bf16(s6[0]); + float 
scale7 = math_cvt_fp32_bf16(s7[0]); + float scale8 = math_cvt_fp32_bf16(s8[0]); + float scale9 = math_cvt_fp32_bf16(s9[0]); + float scale10 = math_cvt_fp32_bf16(s10[0]); + float scale11 = math_cvt_fp32_bf16(s11[0]); + float scale12 = math_cvt_fp32_bf16(s12[0]); + float scale13 = math_cvt_fp32_bf16(s13[0]); + float scale14 = math_cvt_fp32_bf16(s14[0]); + s0 += 1; + s1 += 1; + s2 += 1; + s3 += 1; + s4 += 1; + s5 += 1; + s6 += 1; + s7 += 1; + s8 += 1; + s9 += 1; + s10 += 1; + s11 += 1; + s12 += 1; + s13 += 1; + s14 += 1; + + + packed_k_scaled_sum[0] -= (float)ksum0 * izp * scale0; + packed_k_scaled_sum[1] -= (float)ksum1 * izp * scale1; + packed_k_scaled_sum[2] -= (float)ksum2 * izp * scale2; + packed_k_scaled_sum[3] -= (float)ksum3 * izp * scale3; + packed_k_scaled_sum[4] -= (float)ksum4 * izp * scale4; + packed_k_scaled_sum[5] -= (float)ksum5 * izp * scale5; + packed_k_scaled_sum[6] -= (float)ksum6 * izp * scale6; + packed_k_scaled_sum[7] -= (float)ksum7 * izp * scale7; + packed_k_scaled_sum[8] -= (float)ksum8 * izp * scale8; + packed_k_scaled_sum[9] -= (float)ksum9 * izp * scale9; + packed_k_scaled_sum[10] -= (float)ksum10 * izp * scale10; + packed_k_scaled_sum[11] -= (float)ksum11 * izp * scale11; + packed_k_scaled_sum[12] -= (float)ksum12 * izp * scale12; + packed_k_scaled_sum[13] -= (float)ksum13 * izp * scale13; + packed_k_scaled_sum[14] -= (float)ksum14 * izp * scale14; + + ((uint16_t*) out)[0] = math_cvt_bf16_fp32(scale0 / 16.0f); + ((uint16_t*) out)[1] = math_cvt_bf16_fp32(scale1 / 16.0f); + ((uint16_t*) out)[2] = math_cvt_bf16_fp32(scale2 / 16.0f); + ((uint16_t*) out)[3] = math_cvt_bf16_fp32(scale3 / 16.0f); + ((uint16_t*) out)[4] = math_cvt_bf16_fp32(scale4 / 16.0f); + ((uint16_t*) out)[5] = math_cvt_bf16_fp32(scale5 / 16.0f); + ((uint16_t*) out)[6] = math_cvt_bf16_fp32(scale6 / 16.0f); + ((uint16_t*) out)[7] = math_cvt_bf16_fp32(scale7 / 16.0f); + ((uint16_t*) out)[8] = math_cvt_bf16_fp32(scale8 / 16.0f); + ((uint16_t*) out)[9] = math_cvt_bf16_fp32(scale9 / 16.0f); + ((uint16_t*) out)[10] = math_cvt_bf16_fp32(scale10 / 16.0f); + ((uint16_t*) out)[11] = math_cvt_bf16_fp32(scale11 / 16.0f); + ((uint16_t*) out)[12] = math_cvt_bf16_fp32(scale12 / 16.0f); + ((uint16_t*) out)[13] = math_cvt_bf16_fp32(scale13 / 16.0f); + ((uint16_t*) out)[14] = math_cvt_bf16_fp32(scale14 / 16.0f); + + out += 16 * sizeof(uint16_t); + } + + + if XNN_LIKELY(b != NULL){ + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while(--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while(--nb != 0); + } + out += 16 * sizeof(uint32_t); + } + } while (--g != 0); +} diff --git a/src/qb4-packw/kr-scalar.c.in b/src/qb4-packw/kr-scalar.c.in new file mode 100644 index 00000000000..4035846ed29 --- /dev/null +++ b/src/qb4-packw/kr-scalar.c.in @@ -0,0 +1,200 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
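The template below generates the unrolled x16c8 kernel shown above. As a reading aid, this is the per-group stream the packer writes, as I read it from that code; it is a hedged sketch rather than normative documentation, and the helper name qb4_packed_group_stride is illustrative only, not part of the patch.

    #include <stddef.h>
    #include <stdint.h>

    // Sketch of the packed stream per group of `nr` output channels, for `kc`
    // input channels split into blocks of `bl` (the kernel asserts kc % bl == 0):
    //
    //   nr * sizeof(float)        "k-scaled-sum" header; per channel it accumulates
    //                             -sum over all nibbles of (w - 8) * input_zero_point * block_scale
    //   for each of the kc/bl blocks:
    //     nr * bl / 2 bytes       two 4-bit weights per byte (low nibble = weight k,
    //                             high nibble = weight k + kr within each chunk of 2*kr),
    //                             XORed with 0x88, which flips bit 3 of both nibbles and
    //                             thereby removes the zero point of 8 from each of them
    //     nr * sizeof(uint16_t)   per-channel block scales, stored as bfloat16(scale / 16)
    //   nr * sizeof(uint32_t)     per-channel bias words
    static inline size_t qb4_packed_group_stride(size_t nr, size_t kc, size_t bl) {
      const size_t num_blocks = kc / bl;
      return nr * sizeof(float)                                  // k-scaled-sum header
           + num_blocks * (nr * bl / 2 + nr * sizeof(uint16_t))  // nibbles + block scales
           + nr * sizeof(uint32_t);                              // bias
    }

For the x16c8 entry (nr = 16, kc = bl = 32) this evaluates to 416 bytes per group, which matches the stride the microkernel tester computes further down.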
+ +$assert NR > 1 +$assert KR > 1 +#include +#include +#include +#include + +#include "xnnpack/packw.h" + +void xnn_qb4_packw_gemm_goi_ukernel_x${NR}c${KR}__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + const uint8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes_bl, + size_t extra_bytes_n, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(extra_bytes_bl == nr * sizeof(uint16_t)); + assert(extra_bytes_n == nr * sizeof(float)); + assert(params != NULL); + assert(kc % bl == 0); + size_t num_blocks = kc / bl; + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = (uint32_t) (((const struct xnn_qs8_qc4w_packing_params*) params)->input_zero_point + 0); + + do { + // NC main loop multiple of ${NR} + const uint8_t* w0 = (const uint8_t*) weights; + const uint16_t* s0 = (const uint16_t*) scale; + size_t n = nc; + for (;n >= ${NR}; n -= ${NR}) { + float* packed_k_scaled_sum = (float*) out; + $for N in range(NR): + ((float*) out)[${N}] = 0; + out += ${NR} * sizeof(float); + + // KC/2 bytes is KC Nibbles + $for N in range(1, NR): + const uint8_t* w${N} = w${N-1} + (kc >> 1); + + // scales + $for N in range(1, NR): + const uint16_t* s${N} = s${N-1} + num_blocks; + + + size_t kb = kc; + // Process k by blocks (bl) + for (; kb >= bl; kb-=bl) { + // Initialize KSum as subtracting bl zero points (8) + $for N in range(0, NR): + int32_t ksum${N} = 0; + size_t k = bl; + for(; k >= ${2 * KR}; k-=${2*KR}) { + $for N in range(NR): + $for K in range(0, KR): + const uint8_t v${N}x${2*K} = w${N}[${K}] & 0xF; + const uint8_t v${N}x${2*K+1} = w${N}[${K}] >> 4; + w${N} += ${KR}; + + $for K in range(0, 2*KR): + ksum${N} += (uint32_t) (v${N}x${K}); + // Subtract ${2*KR} zero points (8) + ksum${N} -= ${8*2*KR}; + + $for K in range(0, KR): + out[${N*KR+K}] = (v${N}x${K} | (v${N}x${K+KR} << 4)) ^ 0x88; + + out += ${NR*KR}; + } + $for N in range(NR): + float scale${N} = math_cvt_fp32_bf16(s${N}[0]); + $for N in range(NR): + s${N} += 1; + + + $for N in range(NR): + packed_k_scaled_sum[${N}] -= (float)ksum${N} * izp * scale${N}; + + $for N in range(NR): + ((uint16_t*) out)[${N}] = math_cvt_bf16_fp32(scale${N} / 16.0f); + + out += ${NR} * sizeof(uint16_t); + } + + + if XNN_LIKELY(b != NULL){ + $for N in range(NR): + ((uint32_t*) out)[${N}] = b[${N}]; + b += ${NR}; + } else { + $for N in range(NR): + ((uint32_t*) out)[${N}] = 0; + } + out += ${NR} * sizeof(uint32_t); + w0 = w15; + s0 = s15; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + float* packed_k_scaled_sum = (float*) out; + $for N in range(NR): + ((float*) out)[${N}] = 0; + out += ${NR} * sizeof(float); + $if NR > 2: + // NR remainder has less than ${NR} + $for N in range(1, NR-1): + const uint8_t* w${N} = w${N-1} + (kc >> 1); + const uint16_t* s${N} = s${N-1} + num_blocks; + $if $N % 2 == 0: + if XNN_UNPREDICTABLE(n <= ${N}) { + w${N} = w${N-1}; + s${N} = s${N-1}; + } + $else: + if XNN_UNPREDICTABLE(n < ${N+1}) { + w${N} = w${N-1}; + s${N} = s${N-1}; + } + + size_t kb = kc; + // Process k by blocks (bl) + for (; kb >= bl; kb-=bl) { + // Initialize KSum as subtracting bl zero points (8) + $for N in range(0, NR-1): + int32_t ksum${N} = 0; + size_t k = bl; + for(; k >= ${2 * KR}; k-=${2*KR}) { + $for N in range(NR-1): + $for K in 
range(0, KR): + const uint8_t v${N}x${2*K} = w${N}[${K}] & 0xF; + const uint8_t v${N}x${2*K+1} = w${N}[${K}] >> 4; + w${N} += ${KR}; + + $for K in range(0, 2*KR): + ksum${N} += (uint32_t) (v${N}x${K}); + // Subtract ${2*KR} zero points (8) + ksum${N} -= ${8*2*KR}; + + $for K in range(0, KR): + out[${N*KR+K}] = (v${N}x${K} | (v${N}x${K+KR} << 4)) ^ 0x88; + + out += ${NR*KR}; + } + $for N in range(NR - 1): + float scale${N} = math_cvt_fp32_bf16(s${N}[0]); + $for N in range(NR - 1): + s${N} += 1; + + + $for N in range(NR-1): + packed_k_scaled_sum[${N}] -= (float)ksum${N} * izp * scale${N}; + + $for N in range(NR-1): + ((uint16_t*) out)[${N}] = math_cvt_bf16_fp32(scale${N} / 16.0f); + + out += ${NR} * sizeof(uint16_t); + } + + + if XNN_LIKELY(b != NULL){ + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while(--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while(--nb != 0); + } + out += ${NR} * sizeof(uint32_t); + } + } while (--g != 0); +} diff --git a/src/qb4-packw/qb4-packw.h b/src/qb4-packw/qb4-packw.h new file mode 100644 index 00000000000..09f1b881dc7 --- /dev/null +++ b/src/qb4-packw/qb4-packw.h @@ -0,0 +1,9 @@ + +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +// arch_flags, ukernel, nr, kr, sr, kblock, bl, nr_scale, izp +XNN_QB4_UKERNEL(0, xnn_qb4_packw_gemm_goi_ukernel_x16c8__scalar, 16, 8, 1, 32, 32, 1, 8) +XNN_QB4_UKERNEL(0, xnn_qb4_packw_gemm_goi_ukernel_x16c4__scalar, 16, 4, 1, 32, 32, 1, 8) diff --git a/src/reference/packing.cc b/src/reference/packing.cc index c39f2aa412c..433f16ad265 100644 --- a/src/reference/packing.cc +++ b/src/reference/packing.cc @@ -588,6 +588,8 @@ void xnn_pack_qs8_qb4w_gemm_goi_w( do { const size_t nr_block_size = min(nc - nr_block_start, nr); float* packed_b = (float*) packed_weights; + // Zero out ksum scaled + copy_bias(bias, nr_block_start, nr_block_size, packed_b); packed_weights = (float*) packed_weights + nr_block_size; packed_weights = (float*) packed_weights + (nr - nr_block_size); @@ -679,6 +681,8 @@ void xnn_pack_qs8_qb4w_gemm_gio_w( do { const size_t nr_block_size = min(nc - nr_block_start, nr); int32_t* packed_b = (int32_t*) packed_weights; + // Zero out ksum scaled + copy_bias(bias, nr_block_start, nr_block_size, packed_b); packed_weights = (float*) packed_weights + nr_block_size; packed_weights = (float*) packed_weights + (nr - nr_block_size); diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index e6df9020b14..92d1db7f6b7 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -1376,6 +1376,22 @@ typedef void (*xnn_qs8_qc4w_packw_gemm_goi_ukernel_fn)( size_t extra_bytes, const struct xnn_qs8_qc4w_packing_params* params); +typedef void (*xnn_qb4_packw_gemm_goi_ukernel_fn)( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + const uint8_t* k, + const int32_t* b, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes_bl, + size_t extra_bytes, + const void* params); + typedef void (*xnn_x16_packw_gemm_goi_ukernel_fn)( size_t g, size_t nc, diff --git a/src/xnnpack/microparams.h b/src/xnnpack/microparams.h index 311489b77a8..c73aee47e0e 100644 --- a/src/xnnpack/microparams.h +++ b/src/xnnpack/microparams.h @@ -625,4 +625,3 @@ struct subconvolution_params { // scaled_kernel_size := kernel_size * mr * sizeof(void*). 
size_t scaled_kernel_size; }; - diff --git a/src/xnnpack/packw.h b/src/xnnpack/packw.h index 5ca34d3da0b..7dccf46ea54 100644 --- a/src/xnnpack/packw.h +++ b/src/xnnpack/packw.h @@ -88,6 +88,28 @@ extern "C" { #undef XNN_QS8_UKERNEL #undef XNN_QS8_GIO_UKERNEL +#define XNN_QB4_UKERNEL(arch_flags, ukernel, nr_, kr_, sr_, kblock, bl_size, nr_scale, izp) \ + XNN_INTERNAL void ukernel( \ + size_t g, \ + size_t nc, \ + size_t kc, \ + size_t nr, \ + size_t kr, \ + size_t sr, \ + size_t bl, \ + const uint8_t* weights, \ + const int32_t* bias, \ + const void* scale, \ + int8_t* packed_weights, \ + size_t extra_bytes_bl, \ + size_t extra_bytes_n, \ + const void* params); + +#include "qb4-packw/qb4-packw.h" + +#undef XNN_QB4_UKERNEL + + #define XNN_UKERNEL(arch_flags, ukernel, nr_, kr_, sr_, kblock, nr_scale) \ XNN_INTERNAL void ukernel( \ size_t g, \ diff --git a/test/packw-microkernel-tester.h b/test/packw-microkernel-tester.h index 1e30db3214e..08dc56928f9 100644 --- a/test/packw-microkernel-tester.h +++ b/test/packw-microkernel-tester.h @@ -21,6 +21,7 @@ #include "xnnpack/pack.h" #include "xnnpack/buffer.h" #include "replicable_random_device.h" +#include "xnnpack/microparams-init.h" class PackWMicrokernelTester { public: @@ -97,6 +98,15 @@ class PackWMicrokernelTester { return this->k_; } + PackWMicrokernelTester& bl(size_t bl) { + this->bl_ = bl; + return *this; + } + + size_t bl() const { + return this->bl_; + } + PackWMicrokernelTester& nullbias(bool nullbias) { this->nullbias_ = nullbias; return *this; @@ -493,6 +503,133 @@ class PackWMicrokernelTester { } } + void Test(xnn_qb4_packw_gemm_goi_ukernel_fn packw) const { + xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(int8_t) + n() * k()); + xnnpack::Buffer bias(n()); + xnnpack::Buffer packed_w( + packed_n() * packed_k() + packed_n() * sizeof(uint32_t)); + xnnpack::Buffer packed_w_ref( + packed_n() * packed_k() + packed_n() * sizeof(uint32_t)); + xnnpack::Buffer bf16_scales( + n() * (k() / bl()) + ); + + std::iota(weights.begin(), weights.end(), 0); + std::iota(bias.begin(), bias.end(), UINT32_C(15)); + std::fill(packed_w.begin(), packed_w.end(), INT8_C(0)); + std::fill(packed_w_ref.begin(), packed_w_ref.end(), INT8_C(0)); + std::iota(bf16_scales.begin(), bf16_scales.end(), 3.75); + + const int32_t* bias_data = nullbias() ? nullptr : bias.data(); + const xnn_bfloat16* scale_data = bf16_scales.data(); + const xnn_qs8_qc4w_packing_params packing_params = { 1, 8 }; + + // Compute reference results. 
+ xnn_pack_qs8_qb4w_gemm_goi_w(/*g=*/1, n(), k(), nr(), kr(), sr(), bl(), + weights.data(), + nullptr, + /*scale=*/scale_data, + packed_w_ref.data(), + sizeof(uint16_t) * nr(), + /*extra_bytes=*/sizeof(float) * nr(), &packing_params); + + // fill in scale as second step (reference) + size_t k_stride = round_up_po2(k(), kr() * sr() * 2 /* planes */); + k_stride = round_up_po2(k_stride, 2) >> 1; + size_t k_num_blocks = k() / bl(); + size_t k_bytes = sizeof(int8_t) * k_stride * nr(); + size_t bias_bytes = sizeof(float) * nr(); + size_t ksum_bytes = sizeof(float) * nr(); + size_t block_bytes = sizeof(uint16_t) * k_num_blocks * nr(); + + size_t start_offset = ksum_bytes + k_bytes / k_num_blocks; + size_t stride = ksum_bytes + k_bytes + block_bytes + bias_bytes; + size_t block_stride = (bl() * nr()) / 2 + (sizeof(uint16_t) * nr()); + + xnn_init_blockwise_scale_bf16_params( + /*channels=*/n(), + /*channels_tile=*/nr(), + /*channel_subtile=*/nr(), + /*stride=*/stride, + /*substride=*/stride, + /*num_blocks=*/k_num_blocks, + /*block_stride=*/block_stride, + /*stride_offset=*/0, + /*scale=*/scale_data, + /*packed_w=*/packed_w_ref.data() + start_offset); + + void* bias_start = (void*) ((uintptr_t) packed_w_ref.data() + stride - nr() * sizeof(float)); + + if (!nullbias()){ + xnn_init_qs8_qc8w_scale_fp32_params( + n(), nr(), nr(), stride, stride, 0, (float*) bias_data, bias_start + ); + } + + // Call optimized micro-kernel. + packw(/*g=*/1, n(), k(), nr(), kr(), sr(), bl(), + weights.data(), bias_data, /*scale=*/scale_data, packed_w.data(), sizeof(uint16_t) * nr(), /*extra_bytes=*/sizeof(float) * nr(), &packing_params); + + const uint8_t* packed_data = (uint8_t*)packed_w.data(); + const uint8_t* packed_ref_data = (uint8_t*)packed_w_ref.data(); + + + // Compare Packed Tensors. + for(size_t n_block_start = 0; n_block_start < packed_n(); n_block_start+=nr()){ + // Number of output channels in this block + size_t n_remainder = min(nr(), n() - n_block_start); + // Check KScaledSums + float* kscale_sum_start = (float*) packed_data; + float* kscale_sum_ref_start = (float*) packed_ref_data; + for(size_t ni = 0; ni < n_remainder; ni++){ + EXPECT_EQ((float) kscale_sum_start[ni], (float)kscale_sum_ref_start[ni]) + << "kscaled sum at index: " << ni << " of n_block_start: " << n_block_start << "\n"; + } + + packed_data += nr() * sizeof(float); + packed_ref_data += nr() * sizeof(float); + + for (size_t bl_start = 0; bl_start < k(); bl_start+=bl()){ + // Check nibbles + size_t num_planes_block = bl() / (2 * kr()); + for (size_t pi = 0; pi < num_planes_block; pi += 1) { + for(size_t ni = 0; ni < n_remainder; ni++) { + for(size_t ki = 0; ki < 2*kr(); ki++) { + size_t i = (2 * kr()) * (nr() * pi + ni) + ki; + uint8_t val_ref = ((i & 1) ? (uint8_t)packed_ref_data[i>>1] >> 4 : packed_ref_data[i>>1] &0xF); + uint8_t val = ((i & 1) ? 
(uint8_t)packed_data[i>>1] >> 4 : packed_data[i>>1] &0xF); + EXPECT_EQ(val_ref, val) << " nibbles do not match location at \n" + << "nr_block_start: " << n_block_start << ", plane: " << pi << "\n" + << " ni: " << ni << " ki: " << ki << " i: " << i << "\n"; + } + } + } + packed_data += ((bl() * nr()) >> 1) * sizeof(uint8_t); + packed_ref_data += ((bl() * nr()) >> 1) * sizeof(uint8_t); + // check scales + uint16_t* scales_start = (uint16_t*) packed_data; + uint16_t* scales_ref_start = (uint16_t*) packed_ref_data; + for(size_t ni = 0; ni < n_remainder; ni++){ + // Packing divides the scales by 16, multiplying back is a bit easier for readability + EXPECT_EQ(math_cvt_fp32_bf16(scales_start[ni]) * 16, math_cvt_fp32_bf16(scales_ref_start[ni]) * 16) + << "n_block_start " << n_block_start << " ni " << ni; + } + + packed_data += nr() * sizeof(uint16_t); + packed_ref_data += nr() * sizeof(uint16_t); + } + // check bias + uint32_t* bias_start = (uint32_t*) packed_data; + uint32_t* bias_ref_start = (uint32_t*) packed_ref_data; + for(size_t ni = 0; ni < n_remainder; ni++){ + EXPECT_EQ(bias_start[ni], bias_ref_start[ni]) + << "n_block_start " << n_block_start << " ni " << ni; + } + packed_ref_data += nr() * sizeof(uint32_t); + packed_data += nr() * sizeof(uint32_t); + } + } + private: size_t g_{1}; size_t n_{1}; @@ -500,6 +637,7 @@ class PackWMicrokernelTester { size_t kr_{1}; size_t sr_{1}; size_t k_{1}; + size_t bl_{1}; bool nullbias_{false}; size_t izp_{0}; }; diff --git a/test/qb4-packw.cc b/test/qb4-packw.cc new file mode 100644 index 00000000000..b20461e9541 --- /dev/null +++ b/test/qb4-packw.cc @@ -0,0 +1,148 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
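Before the test cases below, a worked example of the reference bookkeeping in PackWMicrokernelTester::Test above may help. The reference path packs with a null bias first, then xnn_init_blockwise_scale_bf16_params writes the per-block scales at start_offset with the per-group stride, and finally the bias words are copied into the tail of each group via xnn_init_qs8_qc8w_scale_fp32_params. Evaluating the test's own expressions for the x16c8 scalar entry registered in qb4-packw.h (nr = 16, kr = 8, sr = 1, k = bl = 32) gives:

    k_stride     = round_up_po2(32, 8 * 1 * 2) / 2 = 16 bytes of nibbles per channel
    k_num_blocks = 32 / 32                         = 1
    ksum_bytes   = 16 * sizeof(float)              = 64
    k_bytes      = 16 * 16                         = 256
    block_bytes  = 1 * 16 * sizeof(uint16_t)       = 32
    bias_bytes   = 16 * sizeof(float)              = 64
    stride       = 64 + 256 + 32 + 64              = 416 bytes per 16-channel group
    start_offset = 64 + 256 / 1                    = 320 (where the first block's scales go)
    block_stride = (32 * 16) / 2 + 16 * 2          = 288 (nibbles plus scales per block)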
+ + +#include + +#include +#include "xnnpack/common.h" +#include "xnnpack/isa-checks.h" +#include "xnnpack/packw.h" +#include "next_prime.h" +#include "packw-microkernel-tester.h" + +namespace { + +struct XnnTestQB4Param { + const char *name; + xnn_qb4_packw_gemm_goi_ukernel_fn ukernel; + uint64_t arch_flags; + size_t nr, kr, sr, kblock, bl, nr_scale, izp; +}; + +class XnnTestQB4 : public testing::TestWithParam { +}; + +std::string GetTestQB4Name(const testing::TestParamInfo& info) { + return info.param.name; +} + +#define XNN_QB4_UKERNEL(arch_flags, ukernel, nr, kr, sr, kblock, bl, nr_scale, izp) \ + { #ukernel, ukernel, arch_flags, nr, kr, sr, kblock, bl, nr_scale, izp }, + +const XnnTestQB4Param xnn_test_qb4_params[] = { +#include "qb4-packw/qb4-packw.h" +}; + +#undef XNN_QB4_UKERNEL + +TEST_P(XnnTestQB4, null_bias) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + PackWMicrokernelTester() + .nullbias(true) + .n(GetParam().nr * GetParam().nr_scale) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .bl(GetParam().bl) + .izp(GetParam().izp) + .Test(GetParam().ukernel); +} + +TEST_P(XnnTestQB4, bias) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + PackWMicrokernelTester() + .nullbias(false) + .n(GetParam().nr * GetParam().nr_scale) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .bl(GetParam().bl) + .izp(GetParam().izp) + .Test(GetParam().ukernel); +} + +TEST_P(XnnTestQB4, kb_gt_bl_no_bias) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + PackWMicrokernelTester() + .nullbias(true) + .n(GetParam().nr * GetParam().nr_scale) + .k(GetParam().kblock * 2) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .bl(GetParam().bl) + .izp(GetParam().izp) + .Test(GetParam().ukernel); +} + +TEST_P(XnnTestQB4, kb_gt_bl_bias) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + PackWMicrokernelTester() + .nullbias(false) + .n(GetParam().nr * GetParam().nr_scale) + .k(GetParam().kblock * 2) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .bl(GetParam().bl) + .izp(GetParam().izp) + .Test(GetParam().ukernel); +} + +TEST_P(XnnTestQB4, nr_divides_nc) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + PackWMicrokernelTester() + .nullbias(true) + .n(GetParam().nr * GetParam().nr_scale * 2) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .bl(GetParam().bl) + .izp(GetParam().izp) + .Test(GetParam().ukernel); +} + +TEST_P(XnnTestQB4, nr_divides_nc_with_bias) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + PackWMicrokernelTester() + .nullbias(false) + .n(GetParam().nr * GetParam().nr_scale * 2) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .bl(GetParam().bl) + .izp(GetParam().izp) + .Test(GetParam().ukernel); +} + +TEST_P(XnnTestQB4, nc_gt_nr) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + // for(size_t ni = 1; ni < GetParam().nr; ++ni){ + PackWMicrokernelTester() + .nullbias(false) + .n(2) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .bl(GetParam().bl) + .izp(GetParam().izp) + .Test(GetParam().ukernel); + // } +} + + +INSTANTIATE_TEST_SUITE_P(qb4_packw, + XnnTestQB4, + testing::ValuesIn(xnn_test_qb4_params), + GetTestQB4Name); + + +} // namespace From 3f0382475af170da339421c93ccb70d990ba1637 
Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 26 Nov 2024 16:34:01 -0500 Subject: [PATCH 4/5] [Fast Packing] Add packing ukernels to gemm config --- BUILD.bazel | 5 +- CMakeLists.txt | 2 +- cmake/gen/scalar_microkernels.cmake | 4 +- gen/scalar_microkernels.bzl | 4 +- src/configs/gemm-config.c | 3 + src/packw.c | 105 ++++++++++++++++++++++++++++ src/xnnpack/pack.h | 38 ++++++++++ tools/update-microkernels.py | 8 +++ 8 files changed, 163 insertions(+), 6 deletions(-) create mode 100644 src/packw.c diff --git a/BUILD.bazel b/BUILD.bazel index 0a7cdef9b2f..245a722a444 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -783,7 +783,10 @@ xnnpack_cc_library( xnnpack_cxx_library( name = "packing", - srcs = ["src/reference/packing.cc"], + srcs = [ + "src/reference/packing.cc", + "src/packw.c" + ], hdrs = ["src/xnnpack/pack.h"], defines = xnnpack_configurable_defines(), deps = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index 1b1688ed70b..18a8d92ebaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -793,7 +793,7 @@ ADD_LIBRARY(indirection OBJECT src/indirection.c) ADD_LIBRARY(logging OBJECT ${LOGGING_SRCS}) ADD_LIBRARY(microparams-init OBJECT src/microparams-init.c) ADD_LIBRARY(normalization OBJECT src/normalization.c) -ADD_LIBRARY(packing OBJECT src/reference/packing.cc) +ADD_LIBRARY(packing OBJECT src/reference/packing.cc src/packw.c) TARGET_LINK_LIBRARIES(hardware-config PRIVATE xnnpack-base logging) TARGET_LINK_LIBRARIES(indirection PRIVATE xnnpack-base) TARGET_LINK_LIBRARIES(logging PRIVATE xnnpack-base) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index be694e9477e..efd75af706e 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -135,6 +135,8 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/f32-vunary/gen/f32-vabs-scalar.c src/f32-vunary/gen/f32-vneg-scalar.c src/f32-vunary/gen/f32-vsqr-scalar.c + src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c + src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c @@ -541,8 +543,6 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c - src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c - src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index fe920fe5c76..f75030ff538 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -131,6 +131,8 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/f32-vunary/gen/f32-vabs-scalar.c", "src/f32-vunary/gen/f32-vneg-scalar.c", "src/f32-vunary/gen/f32-vsqr-scalar.c", + "src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c", + "src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c", "src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c", "src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c", "src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c", @@ -538,8 +540,6 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c", "src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c", 
"src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c", - "src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c", - "src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c", diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index d5236d1f5d4..357ecbb1a8d 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -1795,6 +1795,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 16; qd8_f32_qb4w_gemm_config.log2_kr = 2; qd8_f32_qb4w_gemm_config.planes = 2; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases; #endif // XNN_ENABLE_ARM_DOTPROD } else { qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane); @@ -1824,6 +1825,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 16; qd8_f32_qb4w_gemm_config.log2_kr = 3; qd8_f32_qb4w_gemm_config.planes = 2; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c8_weights_and_biases; #endif // XNN_ENABLE_ARM_I8MM } else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { #if XNN_ENABLE_ARM_DOTPROD @@ -1834,6 +1836,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 16; qd8_f32_qb4w_gemm_config.log2_kr = 2; qd8_f32_qb4w_gemm_config.planes = 2; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases; #endif // XNN_ENABLE_ARM_DOTPROD } else { qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane); diff --git a/src/packw.c b/src/packw.c new file mode 100644 index 00000000000..9d53896b902 --- /dev/null +++ b/src/packw.c @@ -0,0 +1,105 @@ + +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include "xnnpack.h" +#include "xnnpack/common.h" +#include "xnnpack/config-types.h" +#include "xnnpack/log.h" +#include "xnnpack/math.h" +#include "xnnpack/microfnptr.h" +#include "xnnpack/microparams.h" +#include "xnnpack/microparams-init.h" +#include "xnnpack/packw.h" +#include "xnnpack/pack.h" +#include "xnnpack/unaligned.h" + +void xnn_pack_qb4_x16c8_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params) { + if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { + // No packing ukernel for gio + return xnn_pack_qb4_weights_and_biases( + flags, gemm_config, input_channels, output_channels, groups, + block_size, k_stride, accumulator_init, weights, init_extra_data0_fn, + extra_data0, extra_data0_element_size, init_extra_data1_fn, extra_data1, + extra_data1_element_size, packed_weights_ptr, params); + } + const uint32_t nr = gemm_config->nr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const size_t planes = gemm_config->planes; + + const size_t extra_bytes_bl = sizeof(uint16_t); + const size_t extra_bytes_n = sizeof(uint32_t); + + xnn_qb4_packw_gemm_goi_ukernel_x16c8__scalar( + /*g=*/groups, + /*nc=*/output_channels, + /*kc=*/input_channels, + /*nr=*/nr, + /*kr=*/kr, + /*sr=*/sr, + /*bl=*/block_size, + /*k=*/(const uint8_t*)weights, + /*bias=*/(const int32_t*)accumulator_init, + /*scale=*/(const xnn_bfloat16*)extra_data1, + /*packed_weights=*/(int8_t*)packed_weights_ptr, + /*extra_bytes_bl=*/nr * extra_bytes_bl, + /*extra_bytes_n=*/nr * extra_bytes_n, + /*params*/(const struct xnn_qs8_qc4w_packing_params *)params); +} + +void xnn_pack_qb4_x16c4_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params) { + if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { + // No packing ukernel for gio + return xnn_pack_qb4_weights_and_biases( + flags, gemm_config, input_channels, output_channels, groups, + block_size, k_stride, accumulator_init, weights, init_extra_data0_fn, + extra_data0, extra_data0_element_size, init_extra_data1_fn, extra_data1, + extra_data1_element_size, packed_weights_ptr, params); + } + const uint32_t nr = gemm_config->nr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const size_t planes = gemm_config->planes; + + const size_t extra_bytes_bl = sizeof(uint16_t); + const size_t extra_bytes_n = sizeof(uint32_t); + + xnn_qb4_packw_gemm_goi_ukernel_x16c4__scalar( + /*g=*/groups, + /*nc=*/output_channels, + /*kc=*/input_channels, + /*nr=*/nr, + /*kr=*/kr, + /*sr=*/sr, + /*bl=*/block_size, + /*k=*/(const uint8_t*)weights, + /*bias=*/(const int32_t*)accumulator_init, + /*scale=*/(const 
xnn_bfloat16*)extra_data1, + /*packed_weights=*/(int8_t*)packed_weights_ptr, + /*extra_bytes_bl=*/nr * extra_bytes_bl, + /*extra_bytes_n=*/nr * extra_bytes_n, + /*params*/(const struct xnn_qs8_qc4w_packing_params *)params); +} diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index c58617bec60..7a6f6ba9b15 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -453,6 +453,44 @@ XNN_INTERNAL void xnn_pack_qb4_weights_and_biases( void* packed_weights_ptr, // const void* params); +XNN_INTERNAL void xnn_pack_qb4_x16c4_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + +XNN_INTERNAL void xnn_pack_qb4_x16c8_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + XNN_INTERNAL size_t xnn_packed_stride_qb4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // diff --git a/tools/update-microkernels.py b/tools/update-microkernels.py index e6ecd944c78..f153f4d25c3 100755 --- a/tools/update-microkernels.py +++ b/tools/update-microkernels.py @@ -297,6 +297,14 @@ def main(args): content = config_file.read() microkernels = re.findall(_MICROKERNEL_NAME_REGEX, content) prod_microkernels.update(microkernels) + # Also check prod packing ukernels in packw.c + with open( + os.path.join(src_dir, 'packw.c'), 'r', encoding='utf-8' + ) as packw_file: + content = packw_file.read() + microkernels = re.findall(_MICROKERNEL_NAME_REGEX, content) + prod_microkernels.update(microkernels) + prod_microkernels = set( map(microkernel_name_to_filename.get, prod_microkernels) ) From 452b90ddacb2580c969e33159a05a3c63fcdaa4b Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 26 Nov 2024 11:47:58 -0500 Subject: [PATCH 5/5] [WIP] Packw Benchmarks --- CMakeLists.txt | 1 + bench/packw-benchmark.h | 67 +++++++++++++++++++++++++++++++++++++++++ bench/qb4-packw.cc | 32 ++++++++++++++++++++ bench/qs8-packw.cc | 1 - 4 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 bench/qb4-packw.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 18a8d92ebaa..882f851c772 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1947,6 +1947,7 @@ IF(XNNPACK_BUILD_BENCHMARKS) qu8-gemm qu8-gemm-fp32 qu8-gemm-rndnu + qb4-packw x16-packw x32-packw x8-lut diff --git a/bench/packw-benchmark.h b/bench/packw-benchmark.h index fe11340878a..39b61a892d3 100644 --- a/bench/packw-benchmark.h +++ b/bench/packw-benchmark.h @@ -136,6 +136,73 @@ static void x8_gio_packw(benchmark::State& state, benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); } +static void 
qb4_packw(benchmark::State& state, + xnn_qb4_packw_gemm_goi_ukernel_fn packw, + size_t nr, size_t kr, size_t sr, size_t bl, + benchmark::utils::IsaCheckFunction isa_check = nullptr) +{ + if (isa_check != nullptr && !isa_check(state)) { + return; + } + + const size_t batch = 1; // batch is g parameter for packw + const size_t dim_n = state.range(2); // dim_n is nc parameter + const size_t dim_k = state.range(3); // dim_k is kc parameter + + const size_t rounded_n = benchmark::utils::RoundUp(dim_n, nr); + const size_t rounded_k = benchmark::utils::RoundUp(dim_k, bl); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + + // Computer num_buffers that fit cache with source weights + packed_weights. + const size_t num_buffers = 1 + + benchmark::utils::DivideRoundUp(benchmark::utils::GetMaxCacheSize(), + batch * (dim_n * dim_k + rounded_n * rounded_k + rounded_n * sizeof(uint32_t))); + + xnnpack::Buffer weights(num_buffers * batch * + dim_n * (rounded_k >> 1)); + xnnpack::fill_uniform_random_bits(weights.data(), weights.size(), rng); + xnnpack::Buffer packed_weights( + num_buffers * batch * + (rounded_n * (rounded_k >> 1) + rounded_n * sizeof(uint32_t))); + xnnpack::Buffer bias(num_buffers * batch * dim_n); + xnnpack::fill_uniform_random_bits(bias.data(), bias.size(), rng); + size_t num_blocks = rounded_k / bl; + xnnpack::Buffer bf16_scales(num_blocks * batch * dim_n); + xnnpack::fill_uniform_random_bits(bf16_scales.data(), bf16_scales.size(), rng); + + const xnn_qs8_qc4w_packing_params packing_params = { 1, 8 }; + + size_t buffer_index = 0; + for (auto _ : state) { + if (++buffer_index == num_buffers) { + buffer_index = 0; + } + + packw(1, dim_n, rounded_k, nr, kr, sr, bl, + weights.data() + buffer_index * batch * dim_n * (rounded_k >> 1), + /*bias=*/bias.data() + buffer_index * batch * dim_n, + /*scale=*/bf16_scales.data() + buffer_index * batch * dim_n, + packed_weights.data() + buffer_index * batch * (rounded_n * (rounded_k >> 1) + rounded_n * sizeof(uint32_t) + rounded_n * sizeof(uint16_t)), + /*extra_bytes_bl=*/sizeof(uint16_t) * nr, sizeof(float), &packing_params); + } + + const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); + if (cpu_frequency != 0) { + state.counters["cpufreq"] = cpu_frequency; + } + + const size_t elements_per_iteration = batch * dim_n * (rounded_k >> 1); + state.counters["elements"] = + benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); + + const size_t bytes_per_iteration = (elements_per_iteration + batch * (rounded_n * rounded_k + rounded_n * sizeof(uint32_t))); + state.counters["bytes"] = + benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); +} + + static void qs8_packw(benchmark::State& state, xnn_qs8_packw_gemm_goi_ukernel_fn packw, size_t nr, size_t kr, size_t sr, diff --git a/bench/qb4-packw.cc b/bench/qb4-packw.cc new file mode 100644 index 00000000000..b378b331fe1 --- /dev/null +++ b/bench/qb4-packw.cc @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+
+
+#include
+#include "bgemm.h"
+#include "packw-benchmark.h"
+#include "utils.h"
+#include "xnnpack/common.h"
+#include "xnnpack/hardware-config.h"
+#include "xnnpack/packw.h"
+
+static void qb4_packw(benchmark::State& state, const char* net,
+                      xnn_qb4_packw_gemm_goi_ukernel_fn ukernel,
+                      uint64_t arch_flags, size_t nr, size_t kr, size_t sr, size_t bl) {
+  benchmark::utils::CheckArchFlags(state, arch_flags);
+  qb4_packw(state, ukernel, nr, kr, sr, bl);
+}
+
+#define XNN_QB4_UKERNEL(arch_flags, ukernel, nr, kr, sr, kblock, bl, nr_scale, izp) \
+BENCHMARK_CAPTURE_BGEMM(qb4_packw, ukernel##_, ukernel, arch_flags, nr, kr, sr, bl);
+
+#include "qb4-packw/qb4-packw.h"
+
+#undef XNN_QB4_UKERNEL
+
+
+#ifndef XNNPACK_BENCHMARK_NO_MAIN
+BENCHMARK_MAIN();
+#endif
diff --git a/bench/qs8-packw.cc b/bench/qs8-packw.cc
index e4d8ee97cbb..a0d1e860efa 100644
--- a/bench/qs8-packw.cc
+++ b/bench/qs8-packw.cc
@@ -40,4 +40,3 @@ BENCHMARK_CAPTURE_BGEMM(qs8_gio_packw, ukernel##_, ukernel, arch_flags, nr, kr,
 #ifndef XNNPACK_BENCHMARK_NO_MAIN
 BENCHMARK_MAIN();
 #endif
-