/*************************************************************************** * Copyright (c) 2025, The OpenBLAS Project * All rights reserved. * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in * the documentation and/or other materials provided with the * distribution. * 3. Neither the name of the OpenBLAS project nor the names of * its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ #include #include #include "common.h" int CNAME(BLASLONG m, BLASLONG n, IFLOAT *input, BLASLONG lda, IFLOAT *output) { const int sve_size_bf16 = svcnth(); const int num_accumulators_sve = sve_size_bf16 >> 1; const int num_accumulators = num_accumulators_sve; const int incr_accumulators = 4; const int n_sve_accumulators = (n & -num_accumulators); const int n2 = n & -2; const int n_rest = n - n2; const int m4 = m & -4; const int m_rest = m - m4; size_t n_step = 0; for (; n_step < n_sve_accumulators; n_step += num_accumulators) { const uint16_t* inner_input = input; // Full 4x4 item transposes down the M dimension for (size_t m_step = 0; m_step < m4; m_step += 4) { const uint16_t* tile = inner_input; for (size_t line = 0; line < num_accumulators; line += incr_accumulators) { // Load 4x4 block uint16x4_t a_vec0 = vld1_u16(tile); uint16x4_t a_vec1 = vld1_u16(tile + lda); uint16x4_t a_vec2 = vld1_u16(tile + 2 * lda); uint16x4_t a_vec3 = vld1_u16(tile + 3 * lda); // Transpose 4x4 blocks uint16x4_t out_vec0 = vzip1_u16(a_vec0, a_vec1); uint16x4_t out_vec1 = vzip2_u16(a_vec0, a_vec1); uint16x4_t out_vec2 = vzip1_u16(a_vec2, a_vec3); uint16x4_t out_vec3 = vzip2_u16(a_vec2, a_vec3); // Transpose 8x4 blocks a_vec0 = vreinterpret_u16_u32(vzip1_u32(vreinterpret_u32_u16(out_vec0), vreinterpret_u32_u16(out_vec2))); a_vec1 = vreinterpret_u16_u32(vzip2_u32(vreinterpret_u32_u16(out_vec0), vreinterpret_u32_u16(out_vec2))); a_vec2 = vreinterpret_u16_u32(vzip1_u32(vreinterpret_u32_u16(out_vec1), vreinterpret_u32_u16(out_vec3))); a_vec3 = vreinterpret_u16_u32(vzip2_u32(vreinterpret_u32_u16(out_vec1), vreinterpret_u32_u16(out_vec3))); vst1_u16(output, a_vec0); vst1_u16(output + 4, a_vec1); vst1_u16(output + 8, a_vec2); vst1_u16(output + 12, a_vec3); tile += incr_accumulators; output += 16; } inner_input += incr_accumulators * lda; } if (m_rest) { for (BLASLONG line = 0; line < num_accumulators; line++) { output[0] = inner_input[0]; output[1] = m_rest == 1 ? 0 : *(inner_input + lda); output[2] = m_rest <= 2 ? 0 : *(inner_input + 2 * lda); output[3] = m_rest <= 3 ? 0 : *(inner_input + 3 * lda); inner_input++; output += 4; } } input += num_accumulators; } for (; n_step < n2; n_step += 2) { const uint16_t* inner_input = input; for (size_t m_step = 0; m_step < m4; m_step += 4) { for (BLASLONG line = 0; line < 2; line++) { output[0] = *(inner_input + line); output[1] = *(inner_input + line + lda); output[2] = *(inner_input + line + 2 * lda); output[3] = *(inner_input + line + 3 * lda); output += 4; } inner_input += 4 * lda; } if (m_rest) { for (BLASLONG line = 0; line < 2; line++) { output[0] = *(inner_input + line); output[1] = m_rest == 1 ? 0 : *(inner_input + line + lda); output[2] = m_rest <= 2 ? 0 : *(inner_input + line + 2 * lda); output[3] = m_rest <= 3 ? 0 : *(inner_input + line + 3 * lda); output += 4; } } input += 2; } if (n_rest & 1) { const uint16_t* inner_input = input; for (size_t m_step = 0; m_step < m4; m_step += 4) { output[0] = *inner_input; output[1] = *(inner_input + lda); output[2] = *(inner_input + 2 * lda); output[3] = *(inner_input + 3 * lda); inner_input += 4 * lda; output += 4; } if (m_rest) { output[0] = inner_input[0]; output[1] = m_rest == 1 ? 0 : *(inner_input + lda); output[2] = m_rest <= 2 ? 0 : *(inner_input + 2 * lda); output[3] = m_rest <= 3 ? 0 : *(inner_input + 3 * lda); output += 4; } } return 0; }