Commit dd3ea80

ishika-mi authored and siddhant-mi committed

feature_huggingface: Support opensource llms through HuggingFace

1 parent 2dffc11 commit dd3ea80

File tree

2 files changed: 95 additions, 0 deletions

mindsql/llms/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -2,3 +2,4 @@
 from .googlegenai import GoogleGenAi
 from .llama import LlamaCpp
 from .open_ai import OpenAi
+from .huggingface import HuggingFace
```
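
With this re-export, the new backend becomes importable directly from the package namespace alongside the existing ones. A minimal sketch of the import path, assuming a local `mindsql` installation:

```python
# HuggingFace now sits next to the existing backends in mindsql.llms.
from mindsql.llms import GoogleGenAi, HuggingFace, LlamaCpp, OpenAi
```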

mindsql/llms/huggingface.py

Lines changed: 94 additions & 0 deletions

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizerFast

from .illm import ILlm
from .._utils.constants import LLAMA_VALUE_ERROR, LLAMA_PROMPT_EXCEPTION, CONFIG_REQUIRED_ERROR


class HuggingFace(ILlm):
    def __init__(self, config=None):
        """
        Initialize the class with an optional config parameter.

        Parameters:
            config (any): The configuration parameter.

        Returns:
            None
        """
        if config is None:
            raise ValueError(CONFIG_REQUIRED_ERROR)

        if 'model_name' not in config:
            raise ValueError(LLAMA_VALUE_ERROR)
        model_name = config.pop('model_name') or 'gpt2'

        # The tokenizer is always a LLaMA tokenizer, independent of model_name;
        # the remaining config keys are forwarded to from_pretrained().
        self.tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **config)

    def system_message(self, message: str) -> any:
        """
        Create a system message.

        Parameters:
            message (str): The content of the system message.

        Returns:
            any: A formatted system message.

        Example:
            system_msg = system_message("System update: Server maintenance scheduled.")
        """
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """
        Create a user message.

        Parameters:
            message (str): The content of the user message.

        Returns:
            any: A formatted user message.
        """
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """
        Create an assistant message.

        Parameters:
            message (str): The content of the assistant message.

        Returns:
            any: A formatted assistant message.
        """
        return {"role": "assistant", "content": message}

    def invoke(self, prompt, **kwargs) -> str:
        """
        Submit a prompt to the model for generating a response.

        Parameters:
            prompt (str): The prompt parameter.
            **kwargs: Additional keyword arguments (optional).
                - temperature (float): The temperature parameter for controlling randomness in generation.

        Returns:
            str: The generated response from the model.
        """
        if prompt is None or len(prompt) == 0:
            raise Exception(LLAMA_PROMPT_EXCEPTION)

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2000)
        # Pop temperature (rather than get) so it is not passed to generate()
        # a second time through **kwargs. Note it only takes effect when
        # sampling is enabled (do_sample=True passed via kwargs).
        temperature = kwargs.pop("temperature", 0.1)

        with torch.no_grad():
            output = self.model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask,
                                         max_length=2000, temperature=temperature,
                                         pad_token_id=self.tokenizer.pad_token_id,
                                         eos_token_id=self.tokenizer.eos_token_id,
                                         bos_token_id=self.tokenizer.bos_token_id, **kwargs)

        data = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return data
```
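
A short usage sketch for the class above. The config values, prompt, and generation kwargs are illustrative assumptions, not part of the commit; `model_name` is required and is popped from the config, with every remaining key forwarded to `AutoModelForCausalLM.from_pretrained`. Since the wrapper always loads a LLaMA tokenizer, a LLaMA-family `model_name` is the consistent choice in practice, but `"gpt2"` is shown here because it is the code's own fallback value.

```python
from mindsql.llms import HuggingFace

# 'model_name' is required; any other keys in the dict (none here) are
# forwarded to AutoModelForCausalLM.from_pretrained(). Values illustrative.
llm = HuggingFace(config={"model_name": "gpt2"})

# invoke() tokenizes the prompt (truncated to 2000 tokens), generates under
# torch.no_grad(), and decodes the first returned sequence. Extra kwargs are
# forwarded to model.generate(); do_sample=True makes temperature effective.
sql = llm.invoke(
    "Generate a SQL query that counts the rows in the users table.",
    temperature=0.2,
    do_sample=True,
)
print(sql)
```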
