Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mem0 Integration: Add mem0 as memory provider for RAG #914

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions llama_stack/providers/registry/tool_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,13 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=["mcp"],
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="mem0",
module="llama_stack.providers.remote.tool_runtime.mem0_memory",
config_class="llama_stack.providers.remote.tool_runtime.mem0_memory.config.Mem0ToolRuntimeConfig",
pip_packages=["mem0"],
),
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from llama_stack.providers.datatypes import Api

from .config import Mem0ToolRuntimeConfig
from .memory import Mem0MemoryToolRuntimeImpl


async def get_adapter_impl(config: Mem0ToolRuntimeConfig, _deps):
    """Build and initialize the remote Mem0 tool-runtime adapter.

    Args:
        config: Validated Mem0 runtime configuration.
        _deps: Provider dependencies (unused by this adapter).

    Returns:
        An initialized ``Mem0MemoryToolRuntimeImpl`` instance.
    """
    adapter = Mem0MemoryToolRuntimeImpl(config)
    await adapter.initialize()
    return adapter
19 changes: 19 additions & 0 deletions llama_stack/providers/remote/tool_runtime/mem0_memory/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class Mem0ToolRuntimeConfig(BaseModel):
    """Configuration for Mem0 Tool Runtime"""

    # Base URL of the Mem0 HTTP API; the default targets the hosted service.
    host: Optional[str] = "https://api.mem0.ai"
    # API key for Mem0; when None the implementation falls back to the
    # MEM0_API_KEY environment variable.
    api_key: Optional[str] = None
    # Maximum number of results to request per memory search.
    # NOTE(review): this value does not appear to be forwarded to the search
    # request yet — confirm intended use.
    top_k: int = 10
    # Mem0 organization id; must be provided together with project_id.
    org_id: Optional[str] = None
    # Mem0 project id; must be provided together with org_id.
    project_id: Optional[str] = None
227 changes: 227 additions & 0 deletions llama_stack/providers/remote/tool_runtime/mem0_memory/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import logging
import secrets
import string
import os
from typing import Any, Dict, List, Optional

from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
make_overlapped_chunks,
)

from .config import Mem0ToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.context_retriever import generate_rag_query

import requests
from urllib.parse import urljoin
import json

log = logging.getLogger(__name__)


def make_random_string(length: int = 8):
    """Return a cryptographically secure random alphanumeric string.

    Args:
        length: Number of characters to generate (default 8).
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    """RAG tool runtime backed by the hosted Mem0 memory service.

    Documents are stored as Mem0 "memories" keyed by ``vector_db_id`` (used
    as the Mem0 ``user_id``), and queries are answered via Mem0's semantic
    search endpoint. All communication goes over the Mem0 REST API.
    """

    def __init__(
        self,
        config: Mem0ToolRuntimeConfig,
    ):
        """Set up API credentials/headers and validate connectivity.

        Raises:
            ValueError: if no API key is available, if only one of
                org_id/project_id is set, or if the connectivity probe fails.
        """
        self.config = config

        # Mem0 API configuration; the API key may come from the environment.
        self.api_base_url = config.host
        self.api_key = config.api_key or os.getenv("MEM0_API_KEY")
        self.org_id = config.org_id
        self.project_id = config.project_id

        # Validate configuration: the key is mandatory, and org/project ids
        # are only meaningful as a pair.
        if not self.api_key:
            raise ValueError("Mem0 API Key not provided")
        if (self.org_id is not None) != (self.project_id is not None):
            raise ValueError("Both org_id and project_id must be provided")

        # Headers reused on every request.
        self.headers = {
            "Authorization": f"Token {self.api_key}",
            "Content-Type": "application/json",
        }

        # Fail fast on bad credentials / unreachable host.
        self._validate_api_connection()

    def _validate_api_connection(self):
        """Validate API key and connection by making a test request."""
        try:
            params = {"org_id": self.org_id, "project_id": self.project_id}
            response = requests.get(
                urljoin(self.api_base_url, "/v1/ping/"),
                headers=self.headers,
                params=params,
                timeout=10,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Chain the original error so the root cause stays visible.
            raise ValueError(f"Failed to validate Mem0 API connection: {str(e)}") from e

    async def initialize(self):
        """No-op: all setup happens in __init__."""
        pass

    async def shutdown(self):
        """No-op: requests uses per-call connections; nothing to release."""
        pass

    async def insert(
        self,
        documents: List[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Store each document as a Mem0 memory under ``vector_db_id``.

        Failures for individual documents are logged and skipped so one bad
        document does not abort the whole batch (best-effort semantics).
        The ``chunk_size_in_tokens`` parameter is accepted for interface
        compatibility; Mem0 performs its own chunking server-side.
        """
        for doc in documents:
            content = await content_from_doc(doc)

            # Add to Mem0 memory via API.
            try:
                payload = {
                    "messages": [{"role": "user", "content": content}],
                    "metadata": {"document_id": doc.document_id},
                    "user_id": vector_db_id,
                }
                if self.org_id and self.project_id:
                    payload["org_id"] = self.org_id
                    payload["project_id"] = self.project_id

                response = requests.post(
                    urljoin(self.api_base_url, "/v1/memories/"),
                    headers=self.headers,
                    json=payload,
                    timeout=60,
                )
                # Check the status before touching the body: error responses
                # may not be valid JSON.
                response.raise_for_status()
                log.debug("Mem0 insert response: %s", response.text)
            except requests.exceptions.RequestException:
                # Best-effort: log and continue with the remaining documents.
                log.exception("Failed to insert document %s into Mem0", doc.document_id)

    async def query(
        self,
        content: InterleavedContent,
        vector_db_ids: List[str],
        query_config: Optional[RAGQueryConfig] = None,
    ) -> RAGQueryResult:
        """Search Mem0 across all ``vector_db_ids`` and merge the results.

        Returns a RAGQueryResult whose content wraps the retrieved memories
        in START/END context markers, or ``content=None`` when nothing was
        retrieved (or no databases were given).
        """
        if not vector_db_ids:
            return RAGQueryResult(content=None)

        # NOTE(review): query_config (e.g. top_k) is not yet forwarded to the
        # search request — confirm whether Mem0's search API accepts a limit.
        query_config = query_config or RAGQueryConfig()
        # NOTE(review): `content` is passed straight into the JSON payload;
        # this assumes it is a plain string. Non-string InterleavedContent
        # would need to be flattened to text first — TODO confirm.
        query = content

        # Accumulate results across all requested databases. Using extend()
        # (not assignment) so results from earlier databases are kept.
        mem0_chunks: List[TextContentItem] = []
        for vector_db_id in vector_db_ids:
            try:
                payload = {
                    "query": query,
                    "user_id": vector_db_id,
                }
                if self.org_id and self.project_id:
                    payload["org_id"] = self.org_id
                    payload["project_id"] = self.project_id

                response = requests.post(
                    urljoin(self.api_base_url, "/v1/memories/search/"),
                    headers=self.headers,
                    json=payload,
                    timeout=60,
                )
                response.raise_for_status()

                mem0_results = response.json()
                mem0_chunks.extend(
                    TextContentItem(
                        text=f"id:{result.get('metadata', {}).get('document_id', 'unknown')}; content:{result.get('memory', '')}"
                    )
                    for result in mem0_results
                )
            except requests.exceptions.RequestException:
                # Best-effort: a failed search on one database should not
                # prevent results from the others.
                log.exception("Failed to search Mem0 for %s", vector_db_id)

        if not mem0_chunks:
            return RAGQueryResult(content=None)

        return RAGQueryResult(
            content=[
                TextContentItem(
                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
                ),
                *mem0_chunks,
                TextContentItem(
                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
                ),
            ],
        )

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        # Parameters are not listed since these methods are not yet invoked automatically
        # by the LLM. The method is only implemented so things like /tools can list without
        # encountering fatals.
        return [
            ToolDef(
                name="query_from_mem0",
                description="Retrieve context from mem0",
            ),
            ToolDef(
                name="insert_into_mem0",
                description="Insert documents into mem0",
            ),
        ]

    async def invoke_tool(
        self, tool_name: str, kwargs: Dict[str, Any]
    ) -> ToolInvocationResult:
        """Generic invocation is unsupported; use the RAGToolRuntime methods."""
        raise RuntimeError(
            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
        )
1 change: 1 addition & 0 deletions llama_stack/templates/ollama/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ distribution_spec:
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::mem0
image_type: conda
1 change: 1 addition & 0 deletions llama_stack/templates/ollama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::mem0",
],
}
name = "ollama"
Expand Down
2 changes: 2 additions & 0 deletions llama_stack/templates/ollama/run-with-safety.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,5 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
- toolgroup_id: builtin::rag
provider_id: mem0
5 changes: 5 additions & 0 deletions llama_stack/templates/ollama/run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ providers:
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: mem0
provider_type: remote::mem0
config: {}
- provider_id: brave-search
provider_type: remote::brave-search
config:
Expand Down Expand Up @@ -110,3 +113,5 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
Copy link
Contributor

@ehhuang ehhuang Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late response. The problem is that we don't support multiple providers for
client.tool_runtime.rag_tool.insert:

return await self.routing_table.get_provider_impl("knowledge_search").query(

To support this, we might need to add provider_id as an argument, similar to client.vector_dbs.register for example.

However, I also think it might be time to rework this API client.tool_runtime.rag_tool.insert. Will need some time to think through this.

provider_id: code-interpreter
- toolgroup_id: builtin::rag
provider_id: mem0
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy