Tune whether to use mm or bmm for matmul in inductor max-autotune #118774

Description

@soulitzer

[...] but this is something we can certainly tune in torch.compile with max-autotune. cc @ptrblck @msaroufim @eqy @jerryzh168 @csarofeen @xwang233 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @ezyang @bdhirsh @anijain2305 @zou3519 @peterbell10 @yf225 @ColinPeppler @eellison @shunting314

Originally posted by @lezcano in #118548 (comment)

We can observe a significant difference in runtime for matmul depending on whether requires_grad=True, because depending on requires_grad-ness we will call either mm (plus a copy) or bmm underneath.

Some results for input sizes of the form [768, 768] @ [B, 768, M] with fp16 on an A100 (runtime in ms; the mm + copy column corresponds to requires_grad=True and the bmm column to requires_grad=False):

B          M            mm + copy          bmm
---------------------------------------------------
2048       4            0.152385           1.432394
1024       8            0.147818           0.488963
512        64           0.692234           0.352137
128        128          0.390359           0.129712
64         512          0.773324           0.221456
8          1024         0.228340           0.089903
4          2048         0.229695           0.090430
Click for repro:
import torch
import time

from torch.testing._internal.logging_tensor import capture_logs_with_logging_tensor_mode

def print_table_header():
    print(f"{'b':<10} {'m':<12} {'mm + copy':<18} {'bmm':<18}")

def print_table_row(b, m, time_grad_true, time_grad_false):
    print(f"{b:<10} {m:<12} {time_grad_true:<18.6f} {time_grad_false:<18.6f}")

def measure(B, M):
    def run(x, y):
        out = x @ y
        torch.cuda.synchronize()
        return out

    # [n, k] weight; a_detach shares the same data but without grad tracking
    a = torch.rand([768, 768], dtype=torch.half, device="cuda", requires_grad=True)
    a_detach = a.detach()

    # [B, M, 768], transposed (and made contiguous) to [B, 768, M]
    b = torch.rand([B, M, 768], dtype=torch.half, device="cuda").transpose(1, 2).contiguous()

    def fn(t1, t2):
        return t1 @ t2

#    compiled_fn = torch.compile(fn, mode="max-autotune")
#    compiled_fn(a_detach, b)
#
#    def wrapped(t1, t2):
#       compiled_fn(t1, t2)
#       torch.cuda.synchronize()
#
#    # warm
#    for _ in range(1000):
#        wrapped(a_detach, b)
#
#    before = time.perf_counter()
#
#    # test
#    for _ in range(1000):
#        wrapped(a_detach, b)
#    after = time.perf_counter()
#    print(after - before)


    # print("require_grad=True")
    # warm
    for _ in range(1000):
        run(a, b)

    with capture_logs_with_logging_tensor_mode() as logs:
        run(a, b)
    # print("\n".join(logs))

    before = time.perf_counter()

    # test
    for _ in range(1000):
        run(a, b)
    after = time.perf_counter()
    t_req_grad_true = after - before


    # print("require_grad=False")
    # warm
    for _ in range(1000):
        run(a_detach, b)

    with capture_logs_with_logging_tensor_mode() as logs:
        run(a_detach, b)
    # print("\n".join(logs))
    before = time.perf_counter()

    # test
    for _ in range(1000):
        run(a_detach, b)
    after = time.perf_counter()
    t_req_grad_false = after - before

    print_table_row(B, M, float(t_req_grad_true), float(t_req_grad_false))
    # print("results allclose: ", torch.allclose(run(a, b), run(a_detach, b)))
    # print("max rel diff: ", ((run(a, b) -  run(a_detach, b)) / (run(a, b))).abs().max())


print_table_header()
measure(2048, 4)
measure(1024, 8)
measure(512, 64)
measure(128, 128)
measure(64, 512)
measure(8, 1024)
measure(4, 2048)

This means that when doing nn.Linear with a low batch size, you can get a 2x-3x improvement on that matmul during eval (e.g. there is no need to materialize the large intermediate tensor for backward) simply by setting requires_grad=False. (The PR that caused this regression also made it so that running eval under no_grad gives similar benefits for these input sizes.)
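For example, a minimal sketch of this workaround (using the same tensor shapes as the repro; w and x are made-up names standing in for an nn.Linear-style weight and a batched input):

import torch

# Weight and batched input laid out as in the repro above: [768, 768] @ [B, 768, M].
w = torch.rand(768, 768, dtype=torch.half, device="cuda", requires_grad=True)
x = torch.rand(8, 768, 1024, dtype=torch.half, device="cuda")

# Training-style call: grad tracking on the weight selects the mm + copy path.
out_train = w @ x

# Eval-style call: dropping grad tracking on the weight selects the bmm path,
# which is ~2-3x faster for the small-B / large-M shapes in the table above.
out_eval = w.detach() @ x

# After the PR referenced above, running under no_grad has a similar effect.
with torch.no_grad():
    out_eval_no_grad = w @ x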

What is also interesting is that compiling with max-autotune does NOT help here today! In PT2 we always decompose via the composite implicit kernels, so by the time inductor sees the graph it already contains either mm + copy or bmm (rather than matmul), and inductor today isn't smart enough to transform one strategy into the other.
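To make the two strategies concrete, here is a rough sketch of what each path computes for a [768, 768] @ [B, 768, M] matmul (an illustration only, not the actual decomposition inductor receives):

import torch

a = torch.rand(768, 768, dtype=torch.half, device="cuda")      # [n, k]
b = torch.rand(8, 768, 1024, dtype=torch.half, device="cuda")  # [B, k, M]

def mm_plus_copy(a, b):
    # Fold the batch into a single 2D mm: copy b into [k, B*M], run one mm,
    # then reshape/permute the result back to [B, n, M].
    B, k, M = b.shape
    folded = b.permute(1, 0, 2).reshape(k, B * M)   # non-contiguous -> reshape copies
    out = a @ folded                                # single mm
    return out.reshape(a.shape[0], B, M).permute(1, 0, 2)

def batched(a, b):
    # Broadcast the 2D weight against the batch; for a non-requires_grad input
    # this dispatches to bmm (the right-hand column in the table above).
    return torch.matmul(a, b)

# Both formulations compute the same result; only the kernels (and their
# accumulation order) differ, hence the loose tolerances for fp16.
torch.testing.assert_close(mm_plus_copy(a, b), batched(a, b), rtol=1e-2, atol=1e-2)

The ask in this issue is essentially for max-autotune to benchmark both formulations and pick the faster one per shape, rather than being locked into whichever one the decomposition produced.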


Labels: feature, module: cublas, module: cuda, module: inductor, oncall: pt2, triaged
