LLM Calibration

This repository contains the code for the paper Large Language Models Must Be Taught to Know What They Don't Know by Sanyam Kapoor*, Nate Gruver*, Manley Roberts, Katherine Collins, Arka Pal, Umang Bhatt, Adrian Weller, Samuel Dooley, Micah Goldblum, and Andrew Gordon Wilson.

We fine-tune various large language models to estimate well-calibrated uncertainties in both multiple-choice and open-ended question-answering settings. Fine-tuning uses a dataset of approximately 20,000 generations from each base model, labeled for correctness. At test/inference time, the predicted probability of correctness defines the model's confidence in its answer.
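
Concretely, the confidence can be read off as the probability the model assigns to an affirmative answer to an uncertainty query about its own generation. Below is a minimal sketch of this idea, assuming only the standard transformers API; the prompt template, yes/no token choices, and model id are illustrative and may differ from the repository's actual format (see experiments/play.py for the real usage).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # illustrative base model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

# Ask the model whether its own proposed answer is correct.
query = (
    "Question: What is the capital of France?\n"
    "Proposed Answer: Paris\n"
    "Is the proposed answer correct? Answer yes or no: "
)
inputs = tokenizer(query, return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits[0, -1]  # next-token logits

# Confidence = P("yes"), renormalized over the {yes, no} tokens.
yes_id = tokenizer.encode("yes", add_special_tokens=False)[0]
no_id = tokenizer.encode("no", add_special_tokens=False)[0]
confidence = torch.softmax(logits[[yes_id, no_id]], dim=-1)[0].item()
print(f"P(correct) = {confidence:.2f}")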

HuggingFace Release

We release the following calibration-tuned models as PEFT adapters via HuggingFace, along with the datasets used to train them.

Open-Ended Generation:

  • Llama 2 7B
  • Llama 2 7B Chat
  • Llama 2 13B
  • Llama 2 13B Chat
  • Mistral 7B
  • Mistral 7B Instruct

Multiple-Choice Question-Answering:

  • Llama 2 7B
  • Llama 2 7B Chat
  • Llama 2 13B
  • Llama 2 13B Chat
  • Mistral 7B
  • Mistral 7B Instruct

See experiments/play.py for an example of how to load and use the models.
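
For reference, here is a minimal sketch of what loading one of the released adapters could look like, assuming the standard transformers and peft APIs; the adapter repository id below is a placeholder, not an actual HuggingFace id,

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "mistralai/Mistral-7B-Instruct-v0.2"       # base model (illustrative)
adapter_id = "<hf-org>/<calibration-tuned-adapter>"  # placeholder adapter id

tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_id)  # attach the PEFT adapter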

Environment Setup

Create and activate a new environment, e.g. with conda,

conda create -y -n calibration-tuning python=3.11 pip -c conda-forge
conda activate calibration-tuning

And finally run,

pip install -e .

This will install the llm package.

NOTE: If a different PyTorch CUDA build is required, use an extra index URL, e.g. for CUDA 11.8 run,

pip install --no-cache-dir -U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118

Usage

All arguments of the main function in each of the scripts below can be passed as command-line arguments.

Environment Variables:

  • HF_HOME: Path to the directory where HuggingFace assets (models and datasets) are cached.
  • OPENAI_API_KEY: OpenAI API key, used only for labeling generated datasets and for evaluations.
  • CUDA_VISIBLE_DEVICES: Limits which GPUs are visible to the scripts.

Arguments:

  • --model-name: Possible choices are llama2:7b, llama2:7b-chat, llama2:13b, llama2:13b-chat, mistral:7b, mistral:7b-instruct.
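
For example, a typical invocation might combine the environment variables and arguments above (cache path and GPU ids are illustrative),

HF_HOME=/path/to/hf-cache CUDA_VISIBLE_DEVICES=0 python experiments/generate.py outputs --dataset=all_20k_uniform --prompt-style=oe --model-name=llama2:13b-chat --max-new-tokens=30 --batch-size=16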

Dataset Generation

NOTE: The Story Cloze dataset (2018 version) requires manual download. See instructions here. After getting the CSV file, place it at ${HF_HOME}/datasets/story_cloze/2018.
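
For example, assuming 2018 is the target directory and the CSV keeps its downloaded name (the exact filename expected by the loader is not specified here),

mkdir -p ${HF_HOME}/datasets/story_cloze/2018
cp /path/to/<story-cloze-2018>.csv ${HF_HOME}/datasets/story_cloze/2018/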

Output Generation

To create a CSV dataset of open-ended generations at <outputs-log-dir>/outputs (<outputs-log-dir> is auto-generated, or can be specified explicitly via the --log-dir argument), run,

python experiments/generate.py outputs --dataset=all_20k_uniform --prompt-style=oe --model-name=llama2:13b-chat --max-new-tokens=30 --batch-size=16 --kshot=1

For multiple-choice generations,

python experiments/generate.py outputs --dataset=all_20k_uniform --prompt-style=choice --model-name=llama2:13b-chat --max-new-tokens=1 --batch-size=16

Uncertainty Query Label Generation

To generate the dataset with uncertainty query labels at <labels-log-dir>/labels (auto-generated or specified via --log-dir),

python experiments/generate.py labels --dataset=offline:<outputs-log-dir>/outputs --model-name=llama2:13b-chat --strategy=substring --batch-size=16

Use --strategy=fuzzy_gpt-3.5-turbo-1106 for generating labels via GPT 3.5 Turbo.
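
For example, with the GPT-based labeling strategy (requires OPENAI_API_KEY to be set),

OPENAI_API_KEY=<your-key> python experiments/generate.py labels --dataset=offline:<outputs-log-dir>/outputs --model-name=llama2:13b-chat --strategy=fuzzy_gpt-3.5-turbo-1106 --batch-size=16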

Training

Checkpoints will be saved in an auto-generated directory <train-log-dir>/checkpoint-<step>, or can be configured via --log-dir.

LoRA + Prompt

To use the labeled dataset for calibration-tuning (LoRA + Prompt),

torchrun --nnodes=1 --nproc_per_node=auto experiments/calibration_tune.py --dataset=offline:<labels-log-dir>/labels --model-name=llama2:13b-chat --batch-size=4 --kl-decay=1.0 --max-steps=5000

Use --scale-temp for temperature scaling of the uncertainty query predictions.
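
For example, assuming --scale-temp is a boolean toggle here,

torchrun --nnodes=1 --nproc_per_node=auto experiments/calibration_tune.py --dataset=offline:<labels-log-dir>/labels --model-name=llama2:13b-chat --batch-size=4 --kl-decay=1.0 --max-steps=5000 --scale-temp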

For other CLI arguments, see the main function of experiments/calibration_tune.py.

Probe / LoRA

To use the labeled dataset for training a classifier head (Probe),

torchrun --nnodes=1 --nproc_per_node=auto experiments/classifier_tune.py --dataset=offline:<labels-log-dir>/labels --model-name=llama2:13b-chat --batch-size=4 --max-steps=5000

Use --scale-temp for temperature scaling of the classifier. Use --with-lora to enable trainable LoRA parameters (LoRA).
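
For example, to train a probe with trainable LoRA parameters (assuming --with-lora is a boolean toggle),

torchrun --nnodes=1 --nproc_per_node=auto experiments/classifier_tune.py --dataset=offline:<labels-log-dir>/labels --model-name=llama2:13b-chat --batch-size=4 --max-steps=5000 --with-lora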

For other CLI arguments, see the main function of experiments/classifier_tune.py.

Evaluation

LoRA + Prompt

For open-ended generation evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=oe --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=query

For multiple-choice question-answering evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=choice --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=query_choice

Use --scale-temp=query to apply temperature scaling to the uncertainty query logits.
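
For example,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=oe --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=query --scale-temp=query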

Probe / LoRA

For open-ended generation evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=oe --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=class --with-classifier

For multiple-choice question-answering evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=choice --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=class_choice --with-classifier

Use --scale-temp=probe to apply temperature scaling to the uncertainty query logits.

LICENSE

Apache 2.0

Citation

Please cite our work as:

@inproceedings{kapoor2024llmcalibration,
    title={Large Language Models Must Be Taught to Know What They Don't Know},
    author={Sanyam Kapoor and Nate Gruver and Manley Roberts and Katherine Collins and Arka Pal and Umang Bhatt and Adrian Weller and Samuel Dooley and Micah Goldblum and Andrew Gordon Wilson},
    publisher={arXiv},
    year={2024}
}