
64 add scripts for generating embeddings · stacklok/codegate@3b109e5

Commit 3b109e5

add scripts for generating embeddings

1 parent e48b1c5 commit 3b109e5

File tree

data/archived.jsonl
data/deprecated.jsonl
data/malicious.jsonl
requirements.txt
scripts/import_packages.py
utils/__init__.py
utils/embedding_util.py

7 files changed: 66,480 lines added, 0 removed

Diff for: data/archived.jsonl (+9,309 lines; large diff not rendered by default)

Diff for: data/deprecated.jsonl (+31,572 lines; large diff not rendered by default)

Diff for: data/malicious.jsonl (+25,480 lines; large diff not rendered by default)
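
The data files themselves are not rendered above. Judging from how scripts/import_packages.py (further down) consumes them, each line appears to hold one JSON object with at least name and description fields; the schema's type property suggests an ecosystem field as well. A minimal sketch of reading one such record, assuming that shape and with made-up values:

import json

# Hypothetical example of a single line from data/malicious.jsonl; the values are invented.
line = '{"name": "example-pkg", "type": "npm", "description": "Example package description"}'
package = json.loads(line)
print(package["name"], "-", package["description"])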

Diff for: requirements.txt (+4)

weaviate==0.1.2
weaviate-client==4.9.3
torch==2.5.1
transformers==4.46.3
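
These pins cover the Weaviate Python client plus the PyTorch/Transformers stack used by utils/embedding_util.py below. Installing them with pip install -r requirements.txt appears to be the only environment setup the import script needs, since it starts its own embedded Weaviate instance rather than connecting to a separate server.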

Diff for: scripts/import_packages.py (+75)

import json
from utils.embedding_util import generate_embeddings
import weaviate
from weaviate.embedded import EmbeddedOptions
from weaviate.classes.config import Property, DataType


json_files = [
    'data/archived.jsonl',
    'data/deprecated.jsonl',
    'data/malicious.jsonl',
]


def setup_schema(client):
    # Recreate the "Package" collection from scratch on every run
    if client.collections.exists("Package"):
        client.collections.delete("Package")
    client.collections.create(
        "Package",
        properties=[
            Property(name="name", data_type=DataType.TEXT),
            Property(name="type", data_type=DataType.TEXT),
            Property(name="status", data_type=DataType.TEXT),
            Property(name="description", data_type=DataType.TEXT),
        ]
    )


def add_data(client):
    collection = client.collections.get("Package")

    for json_file in json_files:
        with open(json_file, 'r') as f:
            print("Adding data from", json_file)
            counter = 0
            with collection.batch.dynamic() as batch:
                for line in f:
                    package = json.loads(line)
                    counter += 1
                    # only import the first 100 packages from each file
                    if counter > 100:
                        break

                    # prepare the object for embedding
                    vector_str = f"{package['name']} {package['description']}"
                    vector = generate_embeddings(vector_str)

                    # now add the status column
                    if 'archived' in json_file:
                        package['status'] = 'archived'
                    elif 'deprecated' in json_file:
                        package['status'] = 'deprecated'
                    elif 'malicious' in json_file:
                        package['status'] = 'malicious'
                    else:
                        package['status'] = 'unknown'

                    batch.add_object(properties=package, vector=vector)


def run_import():
    # Run an embedded Weaviate instance that persists to ./weaviate_data
    client = weaviate.WeaviateClient(
        embedded_options=EmbeddedOptions(
            persistence_data_path="./weaviate_data"
        ),
    )
    with client:
        client.connect()
        print('is_ready:', client.is_ready())

        setup_schema(client)
        add_data(client)


if __name__ == '__main__':
    run_import()
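
Not part of the commit, but for illustration: once run_import() has populated the embedded instance, the collection could be searched with a vector produced by the same embedding function. A minimal sketch, assuming the weaviate-client 4.x near_vector query API and an invented query string:

from utils.embedding_util import generate_embeddings
import weaviate
from weaviate.embedded import EmbeddedOptions

# Reuse the embedded instance the import script persisted to ./weaviate_data
client = weaviate.WeaviateClient(
    embedded_options=EmbeddedOptions(
        persistence_data_path="./weaviate_data"
    ),
)
with client:
    client.connect()
    packages = client.collections.get("Package")
    # Embed the query text with the same model used at import time
    results = packages.query.near_vector(
        near_vector=generate_embeddings("deprecated HTTP client"),
        limit=3,
    )
    for obj in results.objects:
        print(obj.properties["name"], obj.properties["status"])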

Diff for: utils/__init__.py (whitespace-only changes)

Diff for: utils/embedding_util.py (+40)

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from torch import Tensor
import os
import warnings

# The transformers library creates this warning internally; it does not
# impact our app and is safe to ignore.
warnings.filterwarnings(action='ignore', category=ResourceWarning)

# We won't have competing threads in this example app
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Initialize tokenizer and model for GTE-base
tokenizer = AutoTokenizer.from_pretrained('thenlper/gte-base')
model = AutoModel.from_pretrained('thenlper/gte-base')


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Zero out padding positions, then average the remaining token embeddings
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def generate_embeddings(text):
    inputs = tokenizer(text, return_tensors='pt',
                       max_length=512, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    attention_mask = inputs['attention_mask']
    embeddings = average_pool(outputs.last_hidden_state, attention_mask)

    # (Optionally) normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)

    return embeddings.numpy().tolist()[0]
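
Not part of the commit: a minimal sketch of calling generate_embeddings directly. Because the vectors are L2-normalized before being returned, the dot product of two of them equals their cosine similarity; the example strings are invented.

from utils.embedding_util import generate_embeddings

a = generate_embeddings("requests: an HTTP library for Python")
b = generate_embeddings("a deprecated HTTP client package")

# Vectors are already unit-length, so the dot product is the cosine similarity
similarity = sum(x * y for x, y in zip(a, b))
print(f"cosine similarity: {similarity:.3f}")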
