diff --git a/examples/matmul/run.cpp b/examples/matmul/run.cpp
index 35d0f09..4e61968 100644
--- a/examples/matmul/run.cpp
+++ b/examples/matmul/run.cpp
@@ -228,19 +228,6 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
   }
 }
 
-/**
- * @brief No-Op shader with matmul bindings for performance testing
- */
-static const char *kShaderNoOp = R"(
-@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
-@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
-@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
-@compute @workgroup_size({{workgroupSize}})
-fn main(
-    @builtin(global_invocation_id) globalID : vec3<u32>) {
-}
-)";
-
 /* 2D block-tiling
  *
  */
@@ -357,6 +344,141 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
   }
 }
 
+/* 2D block-tiling with vectorization
+ *
+ */
+static const char *kShaderMatmulWithVectorization = R"(
+@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
+var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
+var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>,
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>) {
+
+    var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
+    var localM: array<{{precision}}, {{TM}}>;
+    var localN: array<vec4<{{precision}}>, {{TN4}}>;
+
+    let cRow: u32 = groupid.x;
+    let cCol: u32 = groupid.y;
+    let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
+
+    // position of the first c element computed by the thread
+    let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
+    let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
+
+    // aPtr and bPtr are the starting positions of the tiles in a and b,
+    // incremented in the bkidx loop.
+    // cPtr is the starting position of the tile in c which is fixed.
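+    // c is an array of vec4 values, so cPtr counts vec4 elements: its row stride
+    // is N4 = N / 4 and the tile's column offset is BN4 = BN / 4, while a and b
+    // are indexed as scalars with a row stride of K.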
+
+    var aPtr = cRow * {{BM}} * {{K}};
+    var bPtr = cCol * {{BN}} * {{K}};
+    let cPtr = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
+
+    for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
+
+      // Load tile
+      // Load BM x BK by numThread(BM * BN / (TM * TN))
+      // The number of iteration == BM * BK / (BM * BN / (TM * TN))
+      for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
+        tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
+      }
+      // Load BK x BN by numThread(BM * BN / (TM * TN))
+      // The number of iteration == BK * BN / (BM * BN / (TM * TN))
+      for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
+        tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
+      }
+
+      aPtr += {{BK}};
+      bPtr += {{BK}};
+
+      workgroupBarrier();
+      // Compute tile
+      for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
+        for (var idx: u32 = 0; idx < {{TM}}; idx++) {
+          localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
+        }
+        for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
+          localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4    ) * {{BK}} + dotIdx],
+                                            tileB[(threadCol + idx*4 + 1) * {{BK}} + dotIdx],
+                                            tileB[(threadCol + idx*4 + 2) * {{BK}} + dotIdx],
+                                            tileB[(threadCol + idx*4 + 3) * {{BK}} + dotIdx]);
+        }
+        for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
+          for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
+            threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
+          }
+        }
+      }
+      workgroupBarrier();
+    }
+
+    for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
+      for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
+        c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
+      }
+    }
+}
+)";
+
+inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, const size_t M,
+                                                const size_t K, const size_t N, const size_t BM,
+                                                const size_t BK, const size_t BN,
+                                                const size_t TM, const size_t TN,
+                                                const Shape &workgroupSize = {256, 1, 1},
+                                                NumType precision = kf32,
+                                                bool unrolling = false) {
+  assert(BM % TM == 0);
+  assert(BN % TN == 0);
+  assert(K % BK == 0);
+  assert(M % BM == 0);
+  assert(N % BN == 0);
+  // # threads = tile A size == tile B size == # threads for computing C
+  int num_threads = BM * BN / (TM * TN);
+  std::string codeString(shaderTemplate);
+  replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
+                          {"{{precision}}", toString(precision)},
+                          {"{{M}}", toString(M)},
+                          {"{{K}}", toString(K)},
+                          {"{{N}}", toString(N)},
+                          {"{{BM}}", toString(BM)},
+                          {"{{BK}}", toString(BK)},
+                          {"{{BN}}", toString(BN)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)},
+                          {"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
+                          {"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
+                          {"{{TN4}}", toString(TN / 4)},
+                          {"{{N4}}", toString(N / 4)},
+                          {"{{BN4}}", toString(BN / 4)},
+                          });
+  if (unrolling) {
+    std::string unrolledCode = loopUnrolling(codeString);
+    LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
+    return {unrolledCode, workgroupSize};
+  } else {
+    return {codeString, workgroupSize};
+  }
+}
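+
+// With the version-7 parameters chosen in selectMatmul below (BM = BN = 64, BK = 16,
+// TM = TN = 4): num_threads = 256, NUM_TILEA = NUM_TILEB = 4 and TN4 = 1, so each
+// thread loads four elements of tileA and four of tileB per block step and stores
+// one vec4 per row of its 4 x 4 result tile.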
+
+/**
+ * @brief No-Op shader with matmul bindings for performance testing
+ */
+static const char *kShaderNoOp = R"(
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>) {
+}
+)";
+
 inline KernelCode createNoOp(const char *shaderTemplate,
                              const Shape &workgroupSize = {256, 1, 1},
                              NumType precision = kf32) {
@@ -448,6 +570,24 @@ Kernel selectMatmul(Context &ctx, int version,
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ nWorkgroups);
   } else if (version == 7) {
+    static constexpr size_t BM = 64;
+    static constexpr size_t BK = 16;
+    static constexpr size_t BN = 64;
+    static constexpr size_t TM = BM / BK;
+    static constexpr size_t TN = BN / BK;
+    Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
+    Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
+    LOG(kDefLog, kInfo, "M: %d, K: %d, N: %d", M, K, N);
+    LOG(kDefLog, kInfo, "BM: %d, BK: %d, BN: %d, TM: %d, TN: %d", BM, BK, BN, TM, TN);
+    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
+    KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
+                                                      /*wgSize*/ wgSize,
+                                                      kf32,
+                                                      /*Loop unrolling*/ true);
+    kernel = createKernel(ctx, matmul, bindings,
+                          /*nWorkgroups*/ nWorkgroups);
+  } else if (version == 8) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
@@ -528,7 +668,8 @@ int main() {
   // 4 == 2D blocktiling
   // 5 == 1D blocktiling with loop unrolling
   // 6 == 2D blocktiling with loop unrolling
-  // 7 == No-Op
+  // 7 == 2D blocktiling with loop unrolling and vectorization
+  // 8 == No-Op
   size_t M, K, N; // Matrix dimensions
   static constexpr int kTestSize = 2;
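
For reference, a minimal standalone sketch (not part of the patch) of the tile-parameter
arithmetic that createMatmulWithVectorization asserts and substitutes into the shader
template, using the values selectMatmul passes for version 7; the 4096-square matrix
sizes are illustrative:

// Sketch only: mirrors the runtime asserts and the NUM_TILEA / NUM_TILEB / TN4
// substitutions above; compile and run it on its own to check a parameter choice.
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t M = 4096, K = 4096, N = 4096; // illustrative matrix sizes
  constexpr size_t BM = 64, BK = 16, BN = 64;    // block tile handled by one workgroup
  constexpr size_t TM = BM / BK, TN = BN / BK;   // per-thread tile: 4 x 4 outputs
  static_assert(BM % TM == 0 && BN % TN == 0, "thread tile must divide block tile");
  static_assert(M % BM == 0 && N % BN == 0 && K % BK == 0, "block tile must divide matrix dims");
  static_assert(TN % 4 == 0, "not asserted in the patch, but implied by TN4 = TN / 4");
  constexpr size_t numThreads = (BM * BN) / (TM * TN); // threads per workgroup: 256
  constexpr size_t numTileA = (BM * BK) / numThreads;  // NUM_TILEA: tileA loads per thread: 4
  constexpr size_t numTileB = (BN * BK) / numThreads;  // NUM_TILEB: tileB loads per thread: 4
  printf("threads/workgroup=%zu TN4=%zu NUM_TILEA=%zu NUM_TILEB=%zu workgroups=%zux%zu\n",
         numThreads, TN / 4, numTileA, numTileB, M / BM, N / BN);
  return 0;
}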