Skip to content

[WIP] [GSOC] KV Caching for LLM inference #27205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: 5.x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion modules/dnn/include/opencv2/dnn/dnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ CV__DNN_INLINE_NS_BEGIN
DNN_ARG_INPUT=2, //!< input of the whole model. Before Net::forward() or in Net::forward() all inputs must be set
DNN_ARG_OUTPUT=3, //!< output of the model.
DNN_ARG_TEMP=4, //!< intermediate result, a result of some operation and input to some other operation(s).
DNN_ARG_PATTERN=5 //!< not used for now
DNN_ARG_PATTERN=5, //!< not used for now
DNN_ARG_CACHED=6 //!< cached argument, used in some operations to store intermediate results
};

CV_EXPORTS std::string argKindToString(ArgKind kind);
Expand Down Expand Up @@ -276,6 +277,7 @@ CV__DNN_INLINE_NS_BEGIN
CV_PROP_RW std::vector<Mat> blobs;
std::vector<Arg> inputs;
std::vector<Arg> outputs;
std::vector<Arg> cache;
void* netimpl;

virtual std::vector<Ptr<Graph> >* subgraphs() const;
Expand Down Expand Up @@ -1236,6 +1238,8 @@ CV__DNN_INLINE_NS_BEGIN
const uchar* bufferWeightsPtr, size_t bufferWeightsSize);




/** @brief Reads a network model <a href="https://onnx.ai/">ONNX</a>.
* @param onnxFile path to the .onnx file with text description of the network architecture.
* @param engine select DNN engine to be used. With auto selection the new engine is used first and falls back to classic.
Expand Down
20 changes: 18 additions & 2 deletions modules/dnn/src/net_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ struct Net::Impl : public detail::NetImplBase
std::vector<Mat> scratchBufs;
std::vector<Ptr<Graph> > allgraphs;

struct PageInfo {
    // Bookkeeping for one cached argument (e.g. a KV-cache tensor for LLM inference).
    // Storage grows page by page; every page is a Mat of the same `shape`/`dtype`.
    std::vector<cv::Mat> pages;   // allocated cache pages
    int curIdx = -1; // index of last filled block; -1 means no block filled yet
    MatShape shape; // shape of a single block (without the batch dim)
    MatType dtype;  // element type shared by all pages
    // int size; // may need this later
};
// Per-argument cache state, keyed by Arg::idx.
std::unordered_map<int, PageInfo> cache;

Ptr<Graph> mainGraph;
int globGraphIdx;

Expand All @@ -107,10 +117,8 @@ struct Net::Impl : public detail::NetImplBase
// FIXIT use inheritance
virtual Ptr<BackendWrapper> wrap(Mat& host);


virtual void clear();


virtual void validateBackendAndTarget();

void setUpNet(const std::vector<LayerPin>& blobsToKeep_ = std::vector<LayerPin>());
Expand Down Expand Up @@ -335,6 +343,7 @@ struct Net::Impl : public detail::NetImplBase
Arg getArg(const std::string& name);
bool haveArg(const std::string& name) const;

Arg newCachedArg(const std::string& name, bool allowEmptyName);
Arg newConstArg(const std::string& name, const Mat& m);
Arg newConstScalarArg(const std::string& name, int type, const void* value);
Arg newArg(const std::string& name, ArgKind kind, bool allowEmptyName=false);
Expand All @@ -348,6 +357,11 @@ struct Net::Impl : public detail::NetImplBase

void prepareForInference();

// @TODO
void allocateCache(Arg arg, const MatShape& shape, MatType dtype);
void growCache(Arg arg);
const std::vector<Mat>& getCache(Arg arg) const;

// pre-allocates memory for output tensors.
// if useBufferPool==true, the method uses 'buffers'
// for outputs (according to bufidxs)
Expand Down Expand Up @@ -425,6 +439,8 @@ inline Net::Impl* getNetImpl(const Layer* layer)
return reinterpret_cast<Net::Impl*>(layer->netimpl);
}



Net readNetFromONNX2(const String&);
Net readNetFromONNX2(const char*, size_t);
Net readNetFromONNX2(const std::vector<uchar>&);
Expand Down
63 changes: 62 additions & 1 deletion modules/dnn/src/net_impl2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::string argKindToString(ArgKind kind)
kind == DNN_ARG_OUTPUT ? "Output" :
kind == DNN_ARG_TEMP ? "Temp" :
kind == DNN_ARG_PATTERN ? "Pattern" : "???";
}
}

ArgData::ArgData()
{
Expand Down Expand Up @@ -236,6 +236,8 @@ Arg Net::Impl::newConstArg(const std::string& name, const Mat& m)
Arg Net::Impl::newArg(const std::string& name, ArgKind kind, bool allowEmptyName)
{
CV_Assert(allowEmptyName || !name.empty());
CV_Assert(kind != DNN_ARG_CACHED); // use newCachedArg instead

int idx = (int)args.size();

if (!name.empty()) {
Expand All @@ -253,6 +255,27 @@ Arg Net::Impl::newArg(const std::string& name, ArgKind kind, bool allowEmptyName
return Arg(idx);
}

// Adds a new cached Arg (kind == DNN_ARG_CACHED), e.g. for KV-cache storage.
// ArgData for cached args does not hold type and shape for now;
// that information is held by `PageInfo` instead.
// Returns the Arg handle referring to the newly added entry in `args`.
Arg Net::Impl::newCachedArg(const std::string& name, bool allowEmptyName)
{
    CV_Assert(allowEmptyName || !name.empty());
    int idx = (int)args.size();

    if (!name.empty()) {
        // Each named arg must be unique; register the name -> index mapping
        // only after the uniqueness check. (The previous code inserted the
        // name unconditionally before this check, which made the assertion
        // below always fail for non-empty names and also registered empty
        // names, unlike newArg().)
        CV_Assert(argnames.find(name) == argnames.end());
        argnames.insert(std::make_pair(name, (int64_t)idx));
    }

    ArgData adata;
    adata.name = name;
    adata.kind = DNN_ARG_CACHED;
    args.push_back(adata);

    return Arg(idx);
}

int Net::Impl::findDim(const std::string& dimname, bool insert)
{
Expand Down Expand Up @@ -304,6 +327,44 @@ void Net::Impl::prepareForInference()
}
}

// This is called from a Layer, e.g. on the forward pass,
// when the shape of the allocation is known.
// General procedure:
// 1. finds the page list in cache (the entry must already exist)
// 2. `pages` should be empty (nothing allocated yet)
// 3. creates the first page - a Mat of the given type and shape
void Net::Impl::allocateCache(Arg arg, const MatShape& shape, MatType dtype){
    auto it = cache.find(arg.idx);
    CV_Assert(it != cache.end());
    PageInfo& pageInfo = it->second;
    CV_Assert(pageInfo.pages.empty());
    // Fill the existing entry in place via the iterator we already hold,
    // instead of building a fresh PageInfo and re-inserting it with
    // cache[arg.idx] = ... (which did a second hash lookup and copied the
    // whole struct just to overwrite the entry found above).
    pageInfo.pages = { Mat(shape, dtype) };
    pageInfo.curIdx = -1; // no block has been filled yet
    pageInfo.shape = shape;
    pageInfo.dtype = dtype;
}

// Appends one freshly allocated page to an already-initialized cache entry.
// The new page uses the shape/dtype recorded when the cache was allocated.
void Net::Impl::growCache(Arg arg){
    auto it = cache.find(arg.idx);
    CV_Assert(it != cache.end());
    PageInfo& info = it->second;
    // allocateCache() must have created the first page before we can grow
    CV_Assert(!info.pages.empty());
    info.pages.push_back(Mat(info.shape, info.dtype));
}

// Returns all pages currently held for the given cached Arg.
// The cache entry must exist and contain at least one page
// (i.e. allocateCache() must have been called first).
const std::vector<Mat>& Net::Impl::getCache(Arg arg) const {
    const auto it = cache.find(arg.idx);
    CV_Assert(it != cache.end());
    const PageInfo& info = it->second;
    CV_Assert(!info.pages.empty());
    return info.pages;
}

void Net::Impl::allocateLayerOutputs(
const Ptr<Layer>& layer,
const std::vector<int>& inpTypes,
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy