|
@@ -0,0 +1,447 @@ |
|
|
|
|
|
/*********************************************************************/ |
|
|
|
|
|
/* Copyright 2024, The OpenBLAS Project. */ |
|
|
|
|
|
/* All rights reserved. */ |
|
|
|
|
|
/* */ |
|
|
|
|
|
/* Redistribution and use in source and binary forms, with or */ |
|
|
|
|
|
/* without modification, are permitted provided that the following */ |
|
|
|
|
|
/* conditions are met: */ |
|
|
|
|
|
/* */ |
|
|
|
|
|
/* 1. Redistributions of source code must retain the above */ |
|
|
|
|
|
/* copyright notice, this list of conditions and the following */ |
|
|
|
|
|
/* disclaimer. */ |
|
|
|
|
|
/* */ |
|
|
|
|
|
/* 2. Redistributions in binary form must reproduce the above */ |
|
|
|
|
|
/* copyright notice, this list of conditions and the following */ |
|
|
|
|
|
/* disclaimer in the documentation and/or other materials */ |
|
|
|
|
|
/* provided with the distribution. */ |
|
|
|
|
|
/* */ |
|
|
|
|
|
/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */ |
|
|
|
|
|
/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */ |
|
|
|
|
|
/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */ |
|
|
|
|
|
/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */ |
|
|
|
|
|
/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */ |
|
|
|
|
|
/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */ |
|
|
|
|
|
/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */ |
|
|
|
|
|
/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */ |
|
|
|
|
|
/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */ |
|
|
|
|
|
/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */ |
|
|
|
|
|
/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */ |
|
|
|
|
|
/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */ |
|
|
|
|
|
/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */ |
|
|
|
|
|
/* POSSIBILITY OF SUCH DAMAGE. */ |
|
|
|
|
|
/* */ |
|
|
|
|
|
/*********************************************************************/ |
|
|
|
|
|
|
|
|
|
|
|
#include <stdio.h> |
|
|
|
|
|
#include <stdlib.h> |
|
|
|
|
|
#include "common.h" |
|
|
|
|
|
|
|
|
|
|
|
#define SMP_THRESHOLD_MIN 65536.0 |
|
|
|
|
|
#define ERROR_NAME "SBGEMMT " |
|
|
|
|
|
|
|
|
|
|
|
#ifndef GEMM_MULTITHREAD_THRESHOLD |
|
|
|
|
|
#define GEMM_MULTITHREAD_THRESHOLD 4 |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
#ifndef CBLAS |
|
|
|
|
|
|
|
|
|
|
|
void NAME(char *UPLO, char *TRANSA, char *TRANSB, |
|
|
|
|
|
blasint * M, blasint * K, |
|
|
|
|
|
FLOAT * Alpha, |
|
|
|
|
|
IFLOAT * a, blasint * ldA, |
|
|
|
|
|
IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
blasint m, k; |
|
|
|
|
|
blasint lda, ldb, ldc; |
|
|
|
|
|
int transa, transb, uplo; |
|
|
|
|
|
blasint info; |
|
|
|
|
|
|
|
|
|
|
|
char transA, transB, Uplo; |
|
|
|
|
|
blasint nrowa, nrowb; |
|
|
|
|
|
IFLOAT *buffer; |
|
|
|
|
|
IFLOAT *aa, *bb; |
|
|
|
|
|
FLOAT *cc; |
|
|
|
|
|
FLOAT alpha, beta; |
|
|
|
|
|
|
|
|
|
|
|
PRINT_DEBUG_NAME; |
|
|
|
|
|
|
|
|
|
|
|
m = *M; |
|
|
|
|
|
k = *K; |
|
|
|
|
|
|
|
|
|
|
|
alpha = *Alpha; |
|
|
|
|
|
beta = *Beta; |
|
|
|
|
|
|
|
|
|
|
|
lda = *ldA; |
|
|
|
|
|
ldb = *ldB; |
|
|
|
|
|
ldc = *ldC; |
|
|
|
|
|
|
|
|
|
|
|
transA = *TRANSA; |
|
|
|
|
|
transB = *TRANSB; |
|
|
|
|
|
Uplo = *UPLO; |
|
|
|
|
|
TOUPPER(transA); |
|
|
|
|
|
TOUPPER(transB); |
|
|
|
|
|
TOUPPER(Uplo); |
|
|
|
|
|
|
|
|
|
|
|
transa = -1; |
|
|
|
|
|
transb = -1; |
|
|
|
|
|
uplo = -1; |
|
|
|
|
|
|
|
|
|
|
|
if (transA == 'N') |
|
|
|
|
|
transa = 0; |
|
|
|
|
|
if (transA == 'T') |
|
|
|
|
|
transa = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (transA == 'R') |
|
|
|
|
|
transa = 0; |
|
|
|
|
|
if (transA == 'C') |
|
|
|
|
|
transa = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (transB == 'N') |
|
|
|
|
|
transb = 0; |
|
|
|
|
|
if (transB == 'T') |
|
|
|
|
|
transb = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (transB == 'R') |
|
|
|
|
|
transb = 0; |
|
|
|
|
|
if (transB == 'C') |
|
|
|
|
|
transb = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (Uplo == 'U') |
|
|
|
|
|
uplo = 0; |
|
|
|
|
|
if (Uplo == 'L') |
|
|
|
|
|
uplo = 1; |
|
|
|
|
|
nrowa = m; |
|
|
|
|
|
if (transa & 1) nrowa = k; |
|
|
|
|
|
nrowb = k; |
|
|
|
|
|
if (transb & 1) nrowb = m; |
|
|
|
|
|
|
|
|
|
|
|
info = 0; |
|
|
|
|
|
|
|
|
|
|
|
if (ldc < MAX(1, m)) |
|
|
|
|
|
info = 13; |
|
|
|
|
|
if (ldb < MAX(1, nrowb)) |
|
|
|
|
|
info = 10; |
|
|
|
|
|
if (lda < MAX(1, nrowa)) |
|
|
|
|
|
info = 8; |
|
|
|
|
|
if (k < 0) |
|
|
|
|
|
info = 5; |
|
|
|
|
|
if (m < 0) |
|
|
|
|
|
info = 4; |
|
|
|
|
|
if (transb < 0) |
|
|
|
|
|
info = 3; |
|
|
|
|
|
if (transa < 0) |
|
|
|
|
|
info = 2; |
|
|
|
|
|
if (uplo < 0) |
|
|
|
|
|
info = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (info != 0) { |
|
|
|
|
|
BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME)); |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
#else |
|
|
|
|
|
|
|
|
|
|
|
void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, |
|
|
|
|
|
enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint m, |
|
|
|
|
|
blasint k, |
|
|
|
|
|
FLOAT alpha, |
|
|
|
|
|
IFLOAT * A, blasint LDA, |
|
|
|
|
|
IFLOAT * B, blasint LDB, FLOAT beta, FLOAT * c, blasint ldc) |
|
|
|
|
|
{ |
|
|
|
|
|
IFLOAT *aa, *bb; |
|
|
|
|
|
FLOAT *cc; |
|
|
|
|
|
|
|
|
|
|
|
int transa, transb, uplo; |
|
|
|
|
|
blasint info; |
|
|
|
|
|
blasint lda, ldb; |
|
|
|
|
|
IFLOAT *a, *b; |
|
|
|
|
|
XFLOAT *buffer; |
|
|
|
|
|
|
|
|
|
|
|
PRINT_DEBUG_CNAME; |
|
|
|
|
|
|
|
|
|
|
|
uplo = -1; |
|
|
|
|
|
transa = -1; |
|
|
|
|
|
transb = -1; |
|
|
|
|
|
info = 0; |
|
|
|
|
|
|
|
|
|
|
|
if (order == CblasColMajor) { |
|
|
|
|
|
if (Uplo == CblasUpper) uplo = 0; |
|
|
|
|
|
if (Uplo == CblasLower) uplo = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransA == CblasNoTrans) |
|
|
|
|
|
transa = 0; |
|
|
|
|
|
if (TransA == CblasTrans) |
|
|
|
|
|
transa = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransA == CblasConjNoTrans) |
|
|
|
|
|
transa = 0; |
|
|
|
|
|
if (TransA == CblasConjTrans) |
|
|
|
|
|
transa = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransB == CblasNoTrans) |
|
|
|
|
|
transb = 0; |
|
|
|
|
|
if (TransB == CblasTrans) |
|
|
|
|
|
transb = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransB == CblasConjNoTrans) |
|
|
|
|
|
transb = 0; |
|
|
|
|
|
if (TransB == CblasConjTrans) |
|
|
|
|
|
transb = 1; |
|
|
|
|
|
|
|
|
|
|
|
a = (void *)A; |
|
|
|
|
|
b = (void *)B; |
|
|
|
|
|
lda = LDA; |
|
|
|
|
|
ldb = LDB; |
|
|
|
|
|
|
|
|
|
|
|
info = -1; |
|
|
|
|
|
|
|
|
|
|
|
blasint nrowa; |
|
|
|
|
|
blasint nrowb; |
|
|
|
|
|
nrowa = m; |
|
|
|
|
|
if (transa & 1) nrowa = k; |
|
|
|
|
|
nrowb = k; |
|
|
|
|
|
if (transb & 1) nrowb = m; |
|
|
|
|
|
|
|
|
|
|
|
if (ldc < MAX(1, m)) |
|
|
|
|
|
info = 13; |
|
|
|
|
|
if (ldb < MAX(1, nrowb)) |
|
|
|
|
|
info = 10; |
|
|
|
|
|
if (lda < MAX(1, nrowa)) |
|
|
|
|
|
info = 8; |
|
|
|
|
|
if (k < 0) |
|
|
|
|
|
info = 5; |
|
|
|
|
|
if (m < 0) |
|
|
|
|
|
info = 4; |
|
|
|
|
|
if (transb < 0) |
|
|
|
|
|
info = 3; |
|
|
|
|
|
if (transa < 0) |
|
|
|
|
|
info = 2; |
|
|
|
|
|
if (uplo < 0) |
|
|
|
|
|
info = 1; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (order == CblasRowMajor) { |
|
|
|
|
|
|
|
|
|
|
|
a = (void *)B; |
|
|
|
|
|
b = (void *)A; |
|
|
|
|
|
|
|
|
|
|
|
lda = LDB; |
|
|
|
|
|
ldb = LDA; |
|
|
|
|
|
|
|
|
|
|
|
if (Uplo == CblasUpper) uplo = 0; |
|
|
|
|
|
if (Uplo == CblasLower) uplo = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransB == CblasNoTrans) |
|
|
|
|
|
transa = 0; |
|
|
|
|
|
if (TransB == CblasTrans) |
|
|
|
|
|
transa = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransB == CblasConjNoTrans) |
|
|
|
|
|
transa = 0; |
|
|
|
|
|
if (TransB == CblasConjTrans) |
|
|
|
|
|
transa = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransA == CblasNoTrans) |
|
|
|
|
|
transb = 0; |
|
|
|
|
|
if (TransA == CblasTrans) |
|
|
|
|
|
transb = 1; |
|
|
|
|
|
|
|
|
|
|
|
if (TransA == CblasConjNoTrans) |
|
|
|
|
|
transb = 0; |
|
|
|
|
|
if (TransA == CblasConjTrans) |
|
|
|
|
|
transb = 1; |
|
|
|
|
|
|
|
|
|
|
|
info = -1; |
|
|
|
|
|
|
|
|
|
|
|
blasint ncola; |
|
|
|
|
|
blasint ncolb; |
|
|
|
|
|
|
|
|
|
|
|
ncola = m; |
|
|
|
|
|
if (transa & 1) ncola = k; |
|
|
|
|
|
ncolb = k; |
|
|
|
|
|
|
|
|
|
|
|
if (transb & 1) { |
|
|
|
|
|
ncolb = m; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (ldc < MAX(1,m)) |
|
|
|
|
|
info = 13; |
|
|
|
|
|
if (ldb < MAX(1, ncolb)) |
|
|
|
|
|
info = 8; |
|
|
|
|
|
if (lda < MAX(1, ncola)) |
|
|
|
|
|
info = 10; |
|
|
|
|
|
if (k < 0) |
|
|
|
|
|
info = 5; |
|
|
|
|
|
if (m < 0) |
|
|
|
|
|
info = 4; |
|
|
|
|
|
if (transb < 0) |
|
|
|
|
|
info = 2; |
|
|
|
|
|
if (transa < 0) |
|
|
|
|
|
info = 3; |
|
|
|
|
|
if (uplo < 0) |
|
|
|
|
|
info = 1; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (info >= 0) { |
|
|
|
|
|
BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME)); |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#endif |
|
|
|
|
|
int buffer_size; |
|
|
|
|
|
blasint i, j; |
|
|
|
|
|
|
|
|
|
|
|
#ifdef SMP |
|
|
|
|
|
int nthreads; |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef SMP |
|
|
|
|
|
static int (*gemv_thread[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *, |
|
|
|
|
|
BLASLONG, IFLOAT *, BLASLONG, FLOAT, |
|
|
|
|
|
FLOAT *, BLASLONG, int) = { |
|
|
|
|
|
sbgemv_thread_n, sbgemv_thread_t, |
|
|
|
|
|
}; |
|
|
|
|
|
#endif |
|
|
|
|
|
int (*gemv[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG, |
|
|
|
|
|
IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = { |
|
|
|
|
|
SBGEMV_N, SBGEMV_T,}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (m == 0) |
|
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
|
|
IDEBUG_START; |
|
|
|
|
|
|
|
|
|
|
|
const blasint incb = ((transb & 1) == 0) ? 1 : ldb; |
|
|
|
|
|
|
|
|
|
|
|
if (uplo == 1) { |
|
|
|
|
|
for (i = 0; i < m; i++) { |
|
|
|
|
|
j = m - i; |
|
|
|
|
|
|
|
|
|
|
|
aa = a + i; |
|
|
|
|
|
bb = b + i * ldb; |
|
|
|
|
|
if (transa & 1) { |
|
|
|
|
|
aa = a + lda * i; |
|
|
|
|
|
} |
|
|
|
|
|
if (transb & 1) |
|
|
|
|
|
bb = b + i; |
|
|
|
|
|
cc = c + i * ldc + i; |
|
|
|
|
|
|
|
|
|
|
|
#if 0 |
|
|
|
|
|
if (beta != ONE) |
|
|
|
|
|
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0); |
|
|
|
|
|
|
|
|
|
|
|
if (alpha == ZERO) |
|
|
|
|
|
continue; |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
IDEBUG_START; |
|
|
|
|
|
|
|
|
|
|
|
buffer_size = j + k + 128 / sizeof(FLOAT); |
|
|
|
|
|
#ifdef WINDOWS_ABI |
|
|
|
|
|
buffer_size += 160 / sizeof(FLOAT); |
|
|
|
|
|
#endif |
|
|
|
|
|
// for alignment |
|
|
|
|
|
buffer_size = (buffer_size + 3) & ~3; |
|
|
|
|
|
STACK_ALLOC(buffer_size, IFLOAT, buffer); |
|
|
|
|
|
|
|
|
|
|
|
#ifdef SMP |
|
|
|
|
|
|
|
|
|
|
|
if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD) |
|
|
|
|
|
nthreads = 1; |
|
|
|
|
|
else |
|
|
|
|
|
nthreads = num_cpu_avail(2); |
|
|
|
|
|
|
|
|
|
|
|
if (nthreads == 1) { |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
if (!(transa & 1)) |
|
|
|
|
|
(gemv[(int)transa]) (j, k, alpha, aa, lda, |
|
|
|
|
|
bb, incb, beta, cc, 1); |
|
|
|
|
|
else |
|
|
|
|
|
(gemv[(int)transa]) (k, j, alpha, aa, lda, |
|
|
|
|
|
bb, incb, beta, cc, 1); |
|
|
|
|
|
|
|
|
|
|
|
#ifdef SMP |
|
|
|
|
|
} else { |
|
|
|
|
|
if (!(transa & 1)) |
|
|
|
|
|
(gemv_thread[(int)transa]) (j, k, alpha, aa, |
|
|
|
|
|
lda, bb, incb, beta, cc, |
|
|
|
|
|
1, nthreads); |
|
|
|
|
|
else |
|
|
|
|
|
(gemv_thread[(int)transa]) (k, j, alpha, aa, |
|
|
|
|
|
lda, bb, incb, beta, cc, |
|
|
|
|
|
1, nthreads); |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
STACK_FREE(buffer); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
|
|
|
|
for (i = 0; i < m; i++) { |
|
|
|
|
|
j = i + 1; |
|
|
|
|
|
|
|
|
|
|
|
bb = b + i * ldb; |
|
|
|
|
|
if (transb & 1) { |
|
|
|
|
|
bb = b + i; |
|
|
|
|
|
} |
|
|
|
|
|
cc = c + i * ldc; |
|
|
|
|
|
|
|
|
|
|
|
#if 0 |
|
|
|
|
|
if (beta != ONE) |
|
|
|
|
|
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0); |
|
|
|
|
|
|
|
|
|
|
|
if (alpha == ZERO) |
|
|
|
|
|
continue; |
|
|
|
|
|
#endif |
|
|
|
|
|
IDEBUG_START; |
|
|
|
|
|
|
|
|
|
|
|
buffer_size = j + k + 128 / sizeof(FLOAT); |
|
|
|
|
|
#ifdef WINDOWS_ABI |
|
|
|
|
|
buffer_size += 160 / sizeof(FLOAT); |
|
|
|
|
|
#endif |
|
|
|
|
|
// for alignment |
|
|
|
|
|
buffer_size = (buffer_size + 3) & ~3; |
|
|
|
|
|
STACK_ALLOC(buffer_size, IFLOAT, buffer); |
|
|
|
|
|
|
|
|
|
|
|
#ifdef SMP |
|
|
|
|
|
|
|
|
|
|
|
if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD) |
|
|
|
|
|
nthreads = 1; |
|
|
|
|
|
else |
|
|
|
|
|
nthreads = num_cpu_avail(2); |
|
|
|
|
|
|
|
|
|
|
|
if (nthreads == 1) { |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
if (!(transa & 1)) |
|
|
|
|
|
(gemv[(int)transa]) (j, k, alpha, a, lda, bb, |
|
|
|
|
|
incb, beta, cc, 1); |
|
|
|
|
|
else |
|
|
|
|
|
(gemv[(int)transa]) (k, j, alpha, a, lda, bb, |
|
|
|
|
|
incb, beta, cc, 1); |
|
|
|
|
|
|
|
|
|
|
|
#ifdef SMP |
|
|
|
|
|
} else { |
|
|
|
|
|
if (!(transa & 1)) |
|
|
|
|
|
(gemv_thread[(int)transa]) (j, k, alpha, a, lda, |
|
|
|
|
|
bb, incb, beta, cc, 1, |
|
|
|
|
|
nthreads); |
|
|
|
|
|
else |
|
|
|
|
|
(gemv_thread[(int)transa]) (k, j, alpha, a, lda, |
|
|
|
|
|
bb, incb, beta, cc, 1, |
|
|
|
|
|
nthreads); |
|
|
|
|
|
} |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
STACK_FREE(buffer); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
IDEBUG_END; |
|
|
|
|
|
|
|
|
|
|
|
return; |
|
|
|
|
|
} |