compute_u.py
import os
from pathlib import Path
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from rome import repr_tools
from util.globals import *

from .layer_stats import layer_stats
from .rome_hparams import ROMEHyperParams

# Cache variables
inv_mom2_cache = {}


def get_inv_cov(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer_name: str,
    mom2_dataset: str,
    mom2_n_samples: str,
    mom2_dtype: str,
) -> torch.Tensor:
    """
    Retrieves covariance statistics, then computes the algebraic inverse.
    Caches result for future use.
    """

    global inv_mom2_cache

    model_name = model.config._name_or_path.replace("/", "_")
    key = (model_name, layer_name)

    if key not in inv_mom2_cache:
        print(
            f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
            f"The result will be cached to avoid repetitive computation."
        )
        stat = layer_stats(
            model,
            tok,
            layer_name,
            STATS_DIR,
            mom2_dataset,
            to_collect=["mom2"],
            sample_size=mom2_n_samples,
            precision=mom2_dtype,
        )
        inv_mom2_cache[key] = torch.inverse(
            stat.mom2.moment().to("cuda")
        ).float()  # Cast back to float32

    return inv_mom2_cache[key]
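

# Note: get_inv_cov returns C^{-1}, the inverse of the (uncentered) second
# moment C = E[k k^T] of this layer's input activations, estimated by
# layer_stats over mom2_dataset. compute_u below multiplies the averaged
# subject key by this matrix when hparams.mom2_adjustment is enabled.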


def compute_u(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    layer: int,
    context_templates: List[str],
) -> torch.Tensor:
    """
    Computes the left vector (u) used in constructing the rank-1 update matrix.
    """

    print("Computing left vector (u)...")

    # Compute projection token
    word_repr_args = dict(
        model=model,
        tok=tok,
        layer=layer,
        module_template=hparams.rewrite_module_tmp,
        track="in",
    )
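    # hparams.fact_token picks where the key vector is read from the prompt:
    # values of the form "subject_<subtoken>" (e.g. "subject_last") take a
    # subtoken of the subject via get_reprs_at_word_tokens, while "last" takes
    # the final token of the fully formatted prompt.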
    if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0:
        word = request["subject"]
        print(f"Selected u projection object {word}")
        cur_repr = repr_tools.get_reprs_at_word_tokens(
            context_templates=[
                templ.format(request["prompt"]) for templ in context_templates
            ],
            words=[word for _ in range(len(context_templates))],
            subtoken=hparams.fact_token[len("subject_") :],
            **word_repr_args,
        ).mean(0)
    elif hparams.fact_token == "last":
        # Heuristic to choose last word. Not a huge deal if there's a minor
        # edge case (e.g. multi-token word) because the function below will
        # take the last token.
        cur_repr = repr_tools.get_reprs_at_idxs(
            contexts=[
                templ.format(request["prompt"].format(request["subject"]))
                for templ in context_templates
            ],
            idxs=[[-1] for _ in range(len(context_templates))],
            **word_repr_args,
        ).mean(0)
        print("Selected u projection token: last token")
    else:
        raise ValueError(f"fact_token={hparams.fact_token} not recognized")

    # Apply inverse second moment adjustment
    u = cur_repr
    if hparams.mom2_adjustment:
        u = get_inv_cov(
            model,
            tok,
            hparams.rewrite_module_tmp.format(layer),
            hparams.mom2_dataset,
            hparams.mom2_n_samples,
            hparams.mom2_dtype,
        ) @ u.unsqueeze(1)
        u = u.squeeze()

    return u / u.norm()
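

# ---------------------------------------------------------------------------
# Minimal sketch of the mom2 adjustment performed in compute_u, run on random
# tensors rather than real layer statistics. The dimension, sample count, and
# synthetic keys below are illustrative assumptions, not quantities used by
# ROME itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    d, n = 16, 1024  # hypothetical hidden size and sample count
    keys = torch.randn(n, d)  # stand-in for layer input activations
    mom2 = keys.T @ keys / n  # uncentered second moment E[k k^T]
    k = torch.randn(d)  # stand-in for the averaged subject key vector
    u = (torch.inverse(mom2) @ k.unsqueeze(1)).squeeze()  # C^{-1} k
    print("normalized u has shape", (u / u.norm()).shape)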