Browse Source

Fix ?GEMMT

tags/v0.3.24
Martin Kroeker GitHub 2 years ago
parent
commit
38d7a7b562
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 50 additions and 50 deletions
  1. +50
    -50
      interface/gemmt.c

+ 50
- 50
interface/gemmt.c View File

@@ -35,29 +35,26 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include "common.h" #include "common.h"
#ifdef FUNCTION_PROFILE
#include "functable.h"
#endif


#ifndef COMPLEX #ifndef COMPLEX
#define SMP_THRESHOLD_MIN 65536.0 #define SMP_THRESHOLD_MIN 65536.0
#ifdef XDOUBLE #ifdef XDOUBLE
#define ERROR_NAME "QGEMT "
#define ERROR_NAME "QGEMMT "
#elif defined(DOUBLE) #elif defined(DOUBLE)
#define ERROR_NAME "DGEMT "
#define ERROR_NAME "DGEMMT "
#elif defined(BFLOAT16) #elif defined(BFLOAT16)
#define ERROR_NAME "SBGEMT "
#define ERROR_NAME "SBGEMMT "
#else #else
#define ERROR_NAME "SGEMT "
#define ERROR_NAME "SGEMMT "
#endif #endif
#else #else
#define SMP_THRESHOLD_MIN 8192.0 #define SMP_THRESHOLD_MIN 8192.0
#ifdef XDOUBLE #ifdef XDOUBLE
#define ERROR_NAME "XGEMT "
#define ERROR_NAME "XGEMMT "
#elif defined(DOUBLE) #elif defined(DOUBLE)
#define ERROR_NAME "ZGEMT "
#define ERROR_NAME "ZGEMMT "
#else #else
#define ERROR_NAME "CGEMT "
#define ERROR_NAME "CGEMMT "
#endif #endif
#endif #endif


@@ -68,13 +65,13 @@
#ifndef CBLAS #ifndef CBLAS


void NAME(char *UPLO, char *TRANSA, char *TRANSB, void NAME(char *UPLO, char *TRANSA, char *TRANSB,
blasint * M, blasint * N, blasint * K,
blasint * M, blasint * K,
FLOAT * Alpha, FLOAT * Alpha,
IFLOAT * a, blasint * ldA, IFLOAT * a, blasint * ldA,
IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC) IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC)
{ {


blasint m, n, k;
blasint m, k;
blasint lda, ldb, ldc; blasint lda, ldb, ldc;
int transa, transb, uplo; int transa, transb, uplo;
blasint info; blasint info;
@@ -92,7 +89,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
PRINT_DEBUG_NAME; PRINT_DEBUG_NAME;


m = *M; m = *M;
n = *N;
k = *K; k = *K;


#if defined(COMPLEX) #if defined(COMPLEX)
@@ -167,8 +163,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
info = 13; info = 13;
if (k < 0) if (k < 0)
info = 5; info = 5;
if (n < 0)
info = 4;
if (m < 0) if (m < 0)
info = 3; info = 3;
if (transb < 0) if (transb < 0)
@@ -184,7 +178,7 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,


void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M,
blasint N, blasint k,
blasint k,
#ifndef COMPLEX #ifndef COMPLEX
FLOAT alpha, FLOAT alpha,
IFLOAT * A, blasint LDA, IFLOAT * A, blasint LDA,
@@ -205,7 +199,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,


int transa, transb, uplo; int transa, transb, uplo;
blasint info; blasint info;
blasint m, n, lda, ldb;
blasint m, lda, ldb;
FLOAT *a, *b; FLOAT *a, *b;
XFLOAT *buffer; XFLOAT *buffer;


@@ -248,9 +242,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
transb = 3; transb = 3;
#endif #endif


m = M;
n = N;

a = (void *)A; a = (void *)A;
b = (void *)B; b = (void *)B;
lda = LDA; lda = LDA;
@@ -262,8 +253,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
info = 13; info = 13;
if (k < 0) if (k < 0)
info = 5; info = 5;
if (n < 0)
info = 4;
if (m < 0) if (m < 0)
info = 3; info = 3;
if (transb < 0) if (transb < 0)
@@ -273,8 +262,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
} }


if (order == CblasRowMajor) { if (order == CblasRowMajor) {
m = N;
n = M;


a = (void *)B; a = (void *)B;
b = (void *)A; b = (void *)A;
@@ -319,8 +306,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
info = 13; info = 13;
if (k < 0) if (k < 0)
info = 5; info = 5;
if (n < 0)
info = 4;
if (m < 0) if (m < 0)
info = 3; info = 3;
if (transb < 0) if (transb < 0)
@@ -407,37 +392,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,


#endif #endif


if ((m == 0) || (n == 0))
if ((m == 0) )
return; return;


IDEBUG_START; IDEBUG_START;


FUNCTION_PROFILE_START();

const blasint incb = (transb == 0) ? 1 : ldb; const blasint incb = (transb == 0) ? 1 : ldb;


if (uplo == 1) { if (uplo == 1) {
for (i = 0; i < n; i++) {
j = n - i;
for (i = 0; i < m; i++) {
j = m - i;


l = j; l = j;
#if defined(COMPLEX) #if defined(COMPLEX)
aa = a + i * 2; aa = a + i * 2;
bb = b + i * ldb * 2; bb = b + i * ldb * 2;
if (transa) { if (transa) {
l = k;
aa = a + lda * i * 2; aa = a + lda * i * 2;
bb = b + i * 2;
} }
if (transb)
bb = b + i * 2;
cc = c + i * 2 * ldc + i * 2; cc = c + i * 2 * ldc + i * 2;
#else #else
aa = a + i; aa = a + i;
bb = b + i * ldb; bb = b + i * ldb;
if (transa) { if (transa) {
l = k;
aa = a + lda * i; aa = a + lda * i;
bb = b + i;
} }
if (transb)
bb = b + i;
cc = c + i * ldc + i; cc = c + i * ldc + i;
#endif #endif


@@ -458,8 +441,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,


IDEBUG_START; IDEBUG_START;


FUNCTION_PROFILE_START();

buffer_size = j + k + 128 / sizeof(FLOAT); buffer_size = j + k + 128 / sizeof(FLOAT);
#ifdef WINDOWS_ABI #ifdef WINDOWS_ABI
buffer_size += 160 / sizeof(FLOAT); buffer_size += 160 / sizeof(FLOAT);
@@ -479,20 +460,34 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif #endif


#if defined(COMPLEX) #if defined(COMPLEX)
if (!transa)
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, (gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
aa, lda, bb, incb, cc, 1, aa, lda, bb, incb, cc, 1,
buffer); buffer);
else
(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i,
aa, lda, bb, incb, cc, 1,
buffer);
#else #else
if (!transa)
(gemv[(int)transa]) (j, k, 0, alpha, aa, lda, (gemv[(int)transa]) (j, k, 0, alpha, aa, lda,
bb, incb, cc, 1, buffer); bb, incb, cc, 1, buffer);
else
(gemv[(int)transa]) (k, j, 0, alpha, aa, lda,
bb, incb, cc, 1, buffer);
#endif #endif
#ifdef SMP #ifdef SMP
} else { } else {
if (!transa)
(gemv_thread[(int)transa]) (j, k, alpha, aa, (gemv_thread[(int)transa]) (j, k, alpha, aa,
lda, bb, incb, cc, lda, bb, incb, cc,
1, buffer, 1, buffer,
nthreads); nthreads);
else
(gemv_thread[(int)transa]) (k, j, alpha, aa,
lda, bb, incb, cc,
1, buffer,
nthreads);


} }
#endif #endif
@@ -501,21 +496,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
} }
} else { } else {


for (i = 0; i < n; i++) {
for (i = 0; i < m; i++) {
j = i + 1; j = i + 1;


l = j; l = j;
#if defined COMPLEX #if defined COMPLEX
bb = b + i * ldb * 2; bb = b + i * ldb * 2;
if (transa) {
l = k;
if (transb) {
bb = b + i * 2; bb = b + i * 2;
} }
cc = c + i * 2 * ldc; cc = c + i * 2 * ldc;
#else #else
bb = b + i * ldb; bb = b + i * ldb;
if (transa) {
l = k;
if (transb) {
bb = b + i; bb = b + i;
} }
cc = c + i * ldc; cc = c + i * ldc;
@@ -537,8 +530,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif #endif
IDEBUG_START; IDEBUG_START;


FUNCTION_PROFILE_START();

buffer_size = j + k + 128 / sizeof(FLOAT); buffer_size = j + k + 128 / sizeof(FLOAT);
#ifdef WINDOWS_ABI #ifdef WINDOWS_ABI
buffer_size += 160 / sizeof(FLOAT); buffer_size += 160 / sizeof(FLOAT);
@@ -558,30 +549,39 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif #endif


#if defined(COMPLEX) #if defined(COMPLEX)
if (!transa)
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, (gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
a, lda, bb, incb, cc, 1, a, lda, bb, incb, cc, 1,
buffer); buffer);
else
(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i,
a, lda, bb, incb, cc, 1,
buffer);
#else #else
if (!transa)
(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb, (gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb,
incb, cc, 1, buffer); incb, cc, 1, buffer);
else
(gemv[(int)transa]) (k, j, 0, alpha, a, lda, bb,
incb, cc, 1, buffer);
#endif #endif


#ifdef SMP #ifdef SMP
} else { } else {
if (!transa)
(gemv_thread[(int)transa]) (j, k, alpha, a, lda, (gemv_thread[(int)transa]) (j, k, alpha, a, lda,
bb, incb, cc, 1, bb, incb, cc, 1,
buffer, nthreads); buffer, nthreads);

else
(gemv_thread[(int)transa]) (k, j, alpha, a, lda,
bb, incb, cc, 1,
buffer, nthreads);
} }
#endif #endif


STACK_FREE(buffer); STACK_FREE(buffer);
} }
} }
FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE,
args.m * args.k + args.k * args.n +
args.m * args.n, 2 * args.m * args.n * args.k);


IDEBUG_END; IDEBUG_END;




Loading…
Cancel
Save