🐛 Describe the bug
...
File "/home/yyc/.local/miniconda3/envs/zero_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 611, in run
return model(new_inputs)
File "/home/yyc/.local/miniconda3/envs/zero_env/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 892, in _run_from_cache
return compiled_graph.compiled_artifact(inputs)
File "/tmp/torchinductor_yyc/u6/cu6jjzq2vnxkdniqmnw3r7htiuocckvp7ou6qiiyn2ofrtdy7rz7.py", line 503, in call
buf12 = aten._scaled_dot_product_efficient_attention(buf7, buf8, buf9, buf11, True)
File "/home/yyc/.local/miniconda3/envs/zero_env/lib/python3.10/site-packages/torch/_ops.py", line 755, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: invalid dtype for bias - should match query's dtype
when using the inductor backend.
I'm not able to create a minimal reproducer, but I bisected to the triggering commit: huggingface/transformers@20164cc. It looks to me like the nested autocast is not being handled correctly.
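To make that suspicion concrete, here is a minimal sketch of what I think the relevant pattern is (illustration only, not a confirmed standalone reproducer; the shapes and the nested autocast(enabled=False) region are my assumptions about what the mask construction in that commit boils down to): a bf16 query paired with an fp32 additive bias, which matches the dtype mismatch reported above.

import torch
import torch.nn.functional as F

def attn(q, k, v, mask):
    # Hypothetical stand-in for the transformers attention path: the additive
    # mask is built inside a nested autocast(enabled=False) region and stays
    # fp32, while q/k/v come from bf16-autocast projections.
    with torch.autocast('cuda', enabled=False):
        bias = mask.float()
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

q = torch.randn(1, 32, 512, 128, device='cuda', dtype=torch.bfloat16)
k = torch.randn(1, 32, 512, 128, device='cuda', dtype=torch.bfloat16)
v = torch.randn(1, 32, 512, 128, device='cuda', dtype=torch.bfloat16)
mask = torch.zeros(1, 1, 512, 512, device='cuda')  # fp32 additive mask

compiled_attn = torch.compile(attn, backend='inductor')
with torch.autocast('cuda', dtype=torch.bfloat16):
    out = compiled_attn(q, k, v, mask)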
Full reproducer:
The issue occurs when training Llama with transformers==4.38.2 and torch>=2.2.0, and can be reproduced with a (not very minimized) script, run as:
TORCHDYNAMO_DEBUG_FUNCTION='forward' TORCH_LOGS='+dynamo' torchrun train.py --model=meta-llama/Llama-2-7b-hf --cc=inductor
import argparse
import gc
import json
import os
import time

import numpy as np
import torch
import torch.distributed as dist
import transformers
from accelerate import init_empty_weights
from datasets import Dataset
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from transformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig,
                          PreTrainedModel, Trainer, TrainerCallback,
                          TrainingArguments)

assert transformers.__version__ >= '4.36', 'requires transformers 4.36+ to enable sdpa by default'

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='gpt2')
parser.add_argument('--mbs', type=int, default=1)
parser.add_argument('--ga', type=int, default=1)
parser.add_argument('--seq-len', type=int, default=512)
parser.add_argument('--iters', type=int, default=100)
parser.add_argument('--cc', default=None)
parser.add_argument('--layers', type=int, default=2)
parser.add_argument('--embd', type=int, default=-1)
parser.add_argument('--vocab-size', type=int, default=-1)
parser.add_argument('--zero-stage', '--zero', type=int, default=0)
parser.add_argument('--overlap-comm', action="store_true", default=False)
parser.add_argument('--local-rank', type=int, default=-1)
parser.add_argument('--print-gc', action="store_true", default=False)


class ThrputCb(TrainerCallback):
    def __init__(self, iters, seq_len):
        self.t0 = time.perf_counter()
        self.total_iters = iters
        self.seq_len = seq_len
        self.iter = 0
        self.smooth_thrput = 0
        self.smooth_toks_per_dev = 0

    def on_log(self, args, state, control, logs, **kwargs):
        _ = logs.pop("total_flos", None)
        self.iter += 1
        if not state.is_local_process_zero:
            return
        if self.iter > self.total_iters:
            print(
                f'🎉 {self.smooth_thrput:.2f} samples/sec, {self.smooth_toks_per_dev:.2f} toks/GPU')
            return
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        now = time.perf_counter()
        elapsed = now - self.t0
        self.t0 = now
        thrput = world_size * args.per_device_train_batch_size * \
            args.gradient_accumulation_steps / elapsed
        self.smooth_thrput = self.smooth_thrput * 0.9 + thrput * 0.1
        toks_per_dev = thrput / world_size * self.seq_len
        self.smooth_toks_per_dev = self.smooth_toks_per_dev * 0.9 + toks_per_dev * 0.1
        print(f'🐣 Iter {self.iter}/{self.total_iters}: {elapsed*1000:.2f}ms, {thrput:.2f}({self.smooth_thrput:.2f})samples/sec, {toks_per_dev:.2f}({self.smooth_toks_per_dev:.2f})toks/GPU, gc={gc.get_count()}', end=' ')


def print_rank0(s):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(s)


def load_model(model_id: str, layers, embd, vocab_size, zero) -> PreTrainedModel:
    # Reduce dependence on the network environment for llama
    if model_id == 'meta-llama/Llama-2-7b-hf':
        config = LlamaConfig(hidden_size=4096, num_hidden_layers=32,
                             num_attention_heads=32, num_key_value_heads=32, intermediate_size=11008)
    elif model_id == 'meta-llama/Llama-2-13b-hf':
        config = LlamaConfig(hidden_size=5120, num_hidden_layers=40,
                             num_attention_heads=40, num_key_value_heads=40, intermediate_size=13824)
    elif model_id == 'meta-llama/Llama-2-30b-hf':  # not official
        config = LlamaConfig(hidden_size=6400, num_hidden_layers=60,
                             num_attention_heads=40, num_key_value_heads=40, intermediate_size=17280)
    elif model_id == 'meta-llama/Llama-2-70b-hf':
        config = LlamaConfig(hidden_size=8192, num_hidden_layers=80,
                             num_attention_heads=64, num_key_value_heads=8, intermediate_size=28672)
    else:
        if os.path.exists(f'./models/{model_id}'):
            model_id = f'./models/{model_id}'
            print_rank0(f'💾 Loading model from local {model_id}')
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    if layers != -1:
        if hasattr(config, 'num_hidden_layers'):
            config.num_hidden_layers = layers
        elif hasattr(config, 'num_layers'):
            config.num_layers = layers
    if embd != -1:
        if hasattr(config, 'hidden_size'):
            config.hidden_size = embd
        elif hasattr(config, 'n_embd'):
            config.n_embd = embd
    if vocab_size != -1:
        if hasattr(config, 'vocab_size'):
            config.vocab_size = vocab_size
    if zero:
        import deepspeed
        with deepspeed.zero.Init():
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True)
    else:
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True).to_empty(device='cpu')
    model.config.pad_token_id = model.config.eos_token_id
    return model


def guess_transformer_layer_class(model: torch.nn.Module):
    mlist = [m for m in model.modules() if isinstance(m, torch.nn.ModuleList)]
    return mlist[0][0].__class__


class CompileOptimizerTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=torch.scalar_tensor(1e-5), foreach=True, capturable=True)
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=1000, eta_min=1e-6)
        self.optimizer.step = torch.compile(self.optimizer.step)


def train(model, args):
    # deepspeed_dict = json.loads(DEEPSPEED_TEMPLATE)
    # deepspeed_dict['zero_optimization']['stage'] = args.zero_stage
    # deepspeed_dict['zero_optimization']['overlap_comm'] = args.overlap_comm
    # deepspeed_dict['wall_clock_breakdown'] = True
    input_ids = np.random.randint(100, 30000, (1000, args.seq_len))
    data_set = Dataset.from_dict({
        "input_ids": input_ids,
        "labels": input_ids
    })
    training_args = TrainingArguments(
        log_level='info',
        per_device_train_batch_size=args.mbs,
        gradient_accumulation_steps=args.ga,
        max_steps=args.iters,
        save_steps=10e9,
        logging_steps=1,
        output_dir='./tmp',
        disable_tqdm=True,
        bf16=True,
        torch_compile_backend=args.cc,
    )
    trainer_cls = CompileOptimizerTrainer if args.cc and args.zero_stage == 0 else Trainer
    trainer = trainer_cls(model, args=training_args,
                          train_dataset=data_set,
                          callbacks=[ThrputCb(args.iters, args.seq_len)])
    trainer.train()


if __name__ == '__main__':
    if 'RANK' in os.environ and not dist.is_initialized():
        dist.init_process_group(backend='nccl', init_method='env://')
    args = parser.parse_args()
    print_rank0(f'args: {args}')
    if args.zero_stage != 0 and args.fsdp:
        raise ValueError('FSDP and ZeRO are mutually exclusive')
    if args.print_gc:
        gc.set_debug(gc.DEBUG_STATS)
    model: PreTrainedModel = load_model(
        args.model, args.layers, args.embd, args.vocab_size, args.zero_stage != 0)
    print_rank0(f'{model}\nparameters: {model.num_parameters()/1e9:.3f} B')
    train(model, args)
Note that --cc=eager works fine.
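As an aside, and purely as an untested workaround idea: the failing aten._scaled_dot_product_efficient_attention call comes from transformers' SDPA attention path, so selecting the eager attention implementation (the attn_implementation option added in transformers 4.36, not to be confused with --cc=eager) should keep that op out of the compiled graph. A sketch against the load_model() helper above:

# 'eager' here picks transformers' plain matmul+softmax attention instead of
# torch.nn.functional.scaled_dot_product_attention; unrelated to --cc=eager.
# I have not verified that this avoids the dtype mismatch.
model = AutoModelForCausalLM.from_config(
    config, trust_remote_code=True, attn_implementation='eager')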
Versions
Collecting environment information...
PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Alibaba Cloud Linux release 3 (Soaring Falcon) (x86_64)
GCC version: (GCC) 10.2.1 20200825 (Alibaba 10.2.1-3.5 2.32)
Clang version: Could not collect
CMake version: version 3.20.2
Libc version: glibc-2.32
Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.134-16.1.al8.x86_64-x86_64-with-glibc2.32
Is CUDA available: True
CUDA runtime version: 12.3.103
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A10
GPU 1: NVIDIA A10
Nvidia driver version: 535.146.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz
Stepping: 6
CPU MHz: 2900.000
BogoMIPS: 5800.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 48K
L1i cache: 32K
L2 cache: 1280K
L3 cache: 49152K
NUMA node0 CPU(s): 0-63
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm arch_capabilities
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.1
[pip3] triton==2.2.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.2.1 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire