Skip to content

Conversation

@LeiWang1999
Copy link
Contributor

now llvm declares rocm arch in a different way, as we can observe that gfx90a/90c is not an integer only arch name which will cause some issues when generate code under the latest amd gpus:

 ``EF_AMDGPU_MACH_AMDGCN_GFX801``     0x028      ``gfx801``
 ``EF_AMDGPU_MACH_AMDGCN_GFX802``     0x029      ``gfx802``
 ``EF_AMDGPU_MACH_AMDGCN_GFX803``     0x02a      ``gfx803``
 ``EF_AMDGPU_MACH_AMDGCN_GFX810``     0x02b      ``gfx810``
 ``EF_AMDGPU_MACH_AMDGCN_GFX900``     0x02c      ``gfx900``
 ``EF_AMDGPU_MACH_AMDGCN_GFX902``     0x02d      ``gfx902``
 ``EF_AMDGPU_MACH_AMDGCN_GFX904``     0x02e      ``gfx904``
 ``EF_AMDGPU_MACH_AMDGCN_GFX906``     0x02f      ``gfx906``
 ``EF_AMDGPU_MACH_AMDGCN_GFX908``     0x030      ``gfx908``
 ``EF_AMDGPU_MACH_AMDGCN_GFX909``     0x031      ``gfx909``
 ``EF_AMDGPU_MACH_AMDGCN_GFX90C``     0x032      ``gfx90c``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1010``    0x033      ``gfx1010``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1011``    0x034      ``gfx1011``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1012``    0x035      ``gfx1012``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1030``    0x036      ``gfx1030``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1031``    0x037      ``gfx1031``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1032``    0x038      ``gfx1032``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1033``    0x039      ``gfx1033``
 ``EF_AMDGPU_MACH_AMDGCN_GFX602``     0x03a      ``gfx602``
 ``EF_AMDGPU_MACH_AMDGCN_GFX705``     0x03b      ``gfx705``
 ``EF_AMDGPU_MACH_AMDGCN_GFX805``     0x03c      ``gfx805``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1035``    0x03d      ``gfx1035``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1034``    0x03e      ``gfx1034``
 ``EF_AMDGPU_MACH_AMDGCN_GFX90A``     0x03f      ``gfx90a``
 ``EF_AMDGPU_MACH_AMDGCN_GFX940``     0x040      ``gfx940``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1100``    0x041      ``gfx1100``
 ``EF_AMDGPU_MACH_AMDGCN_GFX1013``    0x042      ``gfx1013``

for example, on MI200, the codegen will failed when we specified arch to gfx90a.

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (M, N), dtype="float32")
        B = T.match_buffer(b, (M, N), dtype="float32")
        for i, j in T.grid(M, N):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
ir_module = MyModule
sch = tvm.tir.Schedule(ir_module, debug_mask="all")
block_b = sch.get_block("B")
i, j = sch.get_loops(block_b)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")
with tvm.transform.PassContext():
    rocm_mod = tvm.build(sch.mod, target="rocm -mcpu=gfx90a")

enabling this fix will figure it out.

@tvm-bot
Copy link
Collaborator

tvm-bot commented Jun 13, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: target, rocm See #10317 for details

Generated by tvm-bot

@LeiWang1999
Copy link
Contributor Author

@vinx13 @masahi cc please.

@masahi masahi merged commit 02136b3 into apache:main Jun 15, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants