Description
> [...] but this is something we can certainly tune in torch.compile with max-autotune.
>
> cc @ptrblck @msaroufim @eqy @jerryzh168 @csarofeen @xwang233 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @ezyang @bdhirsh @anijain2305 @zou3519 @peterbell10 @yf225 @ColinPeppler @eellison @shunting314
Originally posted by @lezcano in #118548 (comment)
We can observe a significant difference in matmul runtime depending on whether the input has requires_grad=True, because depending on requires_grad-ness the matmul is decomposed into either bmm or mm (plus a copy) underneath.
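As a quick way to confirm which kernel a given call hits (not part of the original report; the shapes below are just one illustrative case), one can profile a single matmul with and without requires_grad and look for aten::bmm vs aten::mm among the recorded ops:

```python
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.rand(768, 768, dtype=torch.half, device="cuda", requires_grad=True)
b = torch.rand(64, 768, 512, dtype=torch.half, device="cuda")

for x in (a, a.detach()):
    # Profile one call and inspect which aten matmul kernel was dispatched to.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        x @ b
    ops = {evt.key for evt in prof.key_averages()}
    hit = "aten::bmm" if "aten::bmm" in ops else "aten::mm (+ copy)"
    print(f"requires_grad={x.requires_grad}: dispatched to {hit}")
```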
Some results based on input sizes of the form [768, 768] @ [B, 768, M] with fp16 on an A100 (runtime in ms):
| B | M | mm + copy (requires_grad=True) | bmm (requires_grad=False) |
|---:|---:|---:|---:|
| 2048 | 4 | 0.152385 | 1.432394 |
| 1024 | 8 | 0.147818 | 0.488963 |
| 512 | 64 | 0.692234 | 0.352137 |
| 128 | 128 | 0.390359 | 0.129712 |
| 64 | 512 | 0.773324 | 0.221456 |
| 8 | 1024 | 0.228340 | 0.089903 |
| 4 | 2048 | 0.229695 | 0.090430 |
<details>
<summary>click for repro</summary>

```python
import torch
import time
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
    log_input, capture_logs, capture_logs_with_logging_tensor_mode


def print_table_header():
    print(f"{'b':<10} {'m':<12} {'mm + copy':<18} {'bmm':<18}")


def print_table_row(b, m, time_grad_true, time_grad_false):
    print(f"{b:<10} {m:<12} {time_grad_true:<18.6f} {time_grad_false:<18.6f}")


def measure(B, M):
    def run(x, y):
        out = x @ y
        torch.cuda.synchronize()
        return out

    # [n, k]
    a = torch.rand([768, 768], dtype=torch.half, device="cuda", requires_grad=True)
    a_detach = a.detach()
    # [b, m, n]
    b = torch.rand([B, M, 768], dtype=torch.half, device="cuda").transpose(1, 2).contiguous()

    def fn(t1, t2):
        return t1 @ t2

    # compiled_fn = torch.compile(fn, mode="max-autotune")
    # compiled_fn(a_detach, b)
    #
    # def wrapped(t1, t2):
    #     fn(t1, t2)
    #     torch.cuda.synchronize()
    #
    # # warm
    # for _ in range(1000):
    #     wrapped(a_detach, b)
    #
    # before = time.perf_counter()
    #
    # # test
    # for _ in range(1000):
    #     wrapped(a_detach, b)
    # after = time.perf_counter()
    # print(after - before)

    # print("require_grad=True")
    # warm
    for _ in range(1000):
        run(a, b)
    with capture_logs_with_logging_tensor_mode() as logs:
        run(a, b)
    # print("\n".join(logs))
    before = time.perf_counter()
    # test
    for _ in range(1000):
        run(a, b)
    after = time.perf_counter()
    t_req_grad_true = after - before

    # print("require_grad=False")
    # warm
    for _ in range(1000):
        run(a_detach, b)
    with capture_logs_with_logging_tensor_mode() as logs:
        run(a_detach, b)
    # print("\n".join(logs))
    before = time.perf_counter()
    # test
    for _ in range(1000):
        run(a_detach, b)
    after = time.perf_counter()
    t_req_grad_false = after - before

    print_table_row(B, M, float(t_req_grad_true), float(t_req_grad_false))
    # print("results allclose: ", torch.allclose(run(a, b), run(a_detach, b)))
    # print("max rel diff: ", ((run(a, b) - run(a_detach, b)) / (run(a, b))).abs().max())


print_table_header()
measure(2048, 4)
measure(1024, 8)
measure(512, 64)
measure(128, 128)
measure(64, 512)
measure(8, 1024)
measure(4, 2048)
```

</details>
This means that when using nn.Linear with a low batch size, you get 2x-3x improvements on that matmul during eval simply by setting requires_grad=False (e.g., there is no need to materialize the large intermediate tensor for backward). (The PR that caused this regression also made it so that running eval under no_grad gives similar benefits for these input sizes.)
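For instance, a minimal sketch of what that looks like in eval code (the layer and input shapes here are made up for illustration):

```python
import torch
import torch.nn as nn

layer = nn.Linear(768, 768, device="cuda", dtype=torch.half).eval()
x = torch.rand(2048, 4, 768, device="cuda", dtype=torch.half)

# Freeze the weights so the forward matmul is not recorded for autograd and
# can take the cheaper decomposition for these low-batch-size shapes.
layer.requires_grad_(False)
out = layer(x)

# Alternatively, running the forward under no_grad / inference_mode has a
# similar effect (as noted above).
with torch.no_grad():
    out = layer(x)
```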
What is also interesting is that running with torch.compile with max-autotune does NOT help things today! In PT2 we always decompose via the composite implicit kernels, so by the time Inductor sees the graph it already contains either mm + copy or bmm (rather than matmul), and Inductor today isn't smart enough to transform one strategy into the other.
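One way to see this (a sketch, not from the issue; make_fx is used here only as an approximation of the post-decomposition graph that Inductor receives) is to trace below the composite-implicit level and note that matmul no longer appears:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(t1, t2):
    return t1 @ t2

a = torch.rand(768, 768, dtype=torch.half, device="cuda")
b = torch.rand(64, 768, 512, dtype=torch.half, device="cuda")

# Tracing through the dispatcher decomposes CompositeImplicitAutograd ops, so
# the resulting graph should show aten.mm/aten.bmm plus views/copies rather
# than aten.matmul; the choice between the two strategies is already baked in.
gm = make_fx(fn)(a, b)
print(gm.graph)
```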