Add Llama support to Inference Plugin #130092
…r handling and improve error response parsing
…ontent type header
…figuration parsing and serialization
…ration parsing and serialization
… configuration scenarios
…or response handling
…n and truncation behavior
…figuration parsing and serialization
…g-completion # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
…streamline action creation
…d simplify model instantiation
…update similarity measure to DOT_PRODUCT
…g-completion # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
@jonathan-buttner Thank you for your comments. They have been addressed, and the PR is ready for re-review.
Thanks for the changes, left a few more suggestions.
@@ -212,6 +212,7 @@ static TransportVersion def(int id) {
    public static final TransportVersion ESQL_PROFILE_INCLUDE_PLAN_8_19 = def(8_841_0_62);
    public static final TransportVersion ESQL_SPLIT_ON_BIG_VALUES_8_19 = def(8_841_0_63);
    public static final TransportVersion ESQL_FIXED_INDEX_LIKE_8_19 = def(8_841_0_64);
    public static final TransportVersion ML_INFERENCE_LLAMA_ADDED_8_19 = def(8_841_0_65);
Sorry, I forgot to mention this in the previous review: we won't be backporting this to 8.x, so we can remove this transport version.
Removed.
@@ -116,9 +116,16 @@ public String getWriteableName() {

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        assert false : "should never be called when supportsVersion is used";
I believe we can remove this line now because we won't need to backport to 8.x.
Removed.
        return TransportVersions.ML_INFERENCE_LLAMA_ADDED;
    }

    @Override
    public boolean supportsVersion(TransportVersion version) {
Let's remove this override.
Removed.
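With the assertion and the `supportsVersion` override removed, version gating rests on the minimal-version check alone. Below is a minimal sketch of that shape, using stand-in types rather than the real Elasticsearch classes. The `TransportVersion` record, the `VersionedWriteable` interface with its default method, the `LlamaSettingsSketch` class, and the numeric ids are all assumptions made for illustration and do not come from the PR.

```java
public class LlamaVersionSketch {
    // Stand-in for a transport version; NOT the real Elasticsearch class.
    record TransportVersion(int id) {
        boolean onOrAfter(TransportVersion other) {
            return id >= other.id;
        }
    }

    // Hypothetical interface: the default gate is derived from the minimal version.
    interface VersionedWriteable {
        TransportVersion getMinimalSupportedVersion();

        default boolean supportsVersion(TransportVersion version) {
            return version.onOrAfter(getMinimalSupportedVersion());
        }
    }

    // Illustrative id only; the real constant lives in TransportVersions.java.
    static final TransportVersion ML_INFERENCE_LLAMA_ADDED = new TransportVersion(9_000_001);

    // Shape after the review: no assert, no supportsVersion override.
    static class LlamaSettingsSketch implements VersionedWriteable {
        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return ML_INFERENCE_LLAMA_ADDED;
        }
    }

    public static void main(String[] args) {
        VersionedWriteable settings = new LlamaSettingsSketch();
        // A node older than the minimal version is rejected.
        System.out.println(settings.supportsVersion(new TransportVersion(9_000_000))); // false
        // A node on or after the minimal version is accepted.
        System.out.println(settings.supportsVersion(new TransportVersion(9_000_002))); // true
    }
}
```

Keeping a single source of truth for the version check means there is no second code path that has to stay in sync once the 8.x backport constant is dropped.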
@@ -154,9 +154,16 @@ public String getWriteableName() {

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        assert false : "should never be called when supportsVersion is used";
Let's remove this.
        return TransportVersions.ML_INFERENCE_LLAMA_ADDED;
    }

    @Override
Let's remove this method override.
Removed.
@@ -49,7 +47,7 @@ public RateLimitSettings rateLimitSettings() {

    @Override
    public int rateLimitGroupingHash() {
        return 0;
Good catch. In the future, let's add these bug fix changes to a separate PR.
Sure thing!
@@ -141,7 +140,7 @@ public boolean isEnabled() {
        return true;
    }

    protected abstract CustomModel createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
    protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
Thanks for these
No problem.
        }
    }

    public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException {
I believe the base class covers this test. Can you check whether this test covers anything additional? If not, let's remove it from here.
You are correct. I added a check for model_id to the common model assertion method, since it was missing before.
@Jan-Kazlouski-elastic I did some testing, things are looking good. I think there's one scenario we should add better validation error handling for. I was struggling to get the
I think a better experience would be for the PUT request to fail and report back the error it received. This is probably a larger change unrelated to this implementation though. I'll create an issue to improve the validation.
Could you fix the merge conflicts and then I'll approve and merge on Monday 👍
…g-completion # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
…rove error message handling
…g-completion # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
Conflicts are resolved. Adopted changes for Error Handling and for Service constructors from master.
This PR creates a new Llama inference provider integration that allows text_embedding, completion (both streaming and non-streaming), and chat_completion (streaming only) to be executed as part of the inference API.
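The task-type coverage described above can be summarised as a small lookup. This is only an illustrative sketch: the `TaskType` enum and the `supportsStreaming`/`supportsNonStreaming` helpers are hypothetical names, not classes from the plugin, and treating text_embedding as non-streaming only is an assumption (the description does not say so explicitly).

```java
import java.util.EnumSet;
import java.util.Set;

public class LlamaTaskSupportSketch {
    // Hypothetical enum mirroring the task types named in the PR description.
    enum TaskType { TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION }

    // completion and chat_completion can be streamed.
    static final Set<TaskType> STREAMING =
        EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

    // completion also works non-streaming; text_embedding here is an assumption.
    static final Set<TaskType> NON_STREAMING =
        EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

    static boolean supportsStreaming(TaskType type) {
        return STREAMING.contains(type);
    }

    static boolean supportsNonStreaming(TaskType type) {
        return NON_STREAMING.contains(type);
    }

    public static void main(String[] args) {
        for (TaskType type : TaskType.values()) {
            System.out.println(type + " streaming=" + supportsStreaming(type)
                + " nonStreaming=" + supportsNonStreaming(type));
        }
    }
}
```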
Changes were tested locally against the following models:
The ollama service was used for testing.
Quickstart for setting up and running the llama service locally: https://llama-stack.readthedocs.io/en/latest/getting_started/index.html
Setup
Install uv
Download and run ollama: https://ollama.com/download
Clone the llama stack repo: git clone git@github.com:meta-llama/llama-stack.git, then follow the detailed instructions in the docs above.
Running `all-minilm:l6-v2`
Download the model:
Examples of requests and responses from local testing:
Create Embedding Endpoint
No URL:
No API key (success):
Not Found:
Success:
Perform Embedding
Bad Request:
Success:
Create Completion Endpoint
No URL:
Success:
Perform Completion
Success (Non-Streaming):
Success (Streaming):
Bad Request (Non-Streaming):
Bad Request (Streaming):
Create Chat Completion Endpoint
No URL:
Success:
Perform Chat Completion
Success (basic):
Success (Complex):
Invalid Model:
gradle check?