Bamba architecture #10810

Closed · wants to merge 43 commits

Conversation

gabe-l-hart
Contributor

Description

This PR adds support for the forthcoming bamba architecture from IBM Research. It is a hybrid SSM architecture, similar in nature to jamba, but it uses the mamba2 mixer instead of the original mamba mixer.

Dependencies

This PR is based on in-flight work for this new model, which will be published soon. There are other in-flight PRs for this model:

NOTE: To run the conversion steps, you will need to install the transformers branch from the above PR until it is merged and released.

TODOs

  • Figure out the trajectory for the mamba2 PR
  • Fix support on Metal
  • Ensure support with CUDA

Changes

This PR includes some large structural changes in addition to the standard pattern for contributing a new architecture. These changes are needed because the KV cache is used differently for attention layers than for recurrent layers.

Library Interface

  • Add llama_model_is_hybrid. This mirrors the llama_model_is_recurrent function and currently returns true only for Bamba (a minimal sketch follows).
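
A minimal sketch of what this could look like, assuming it follows the same switch-on-architecture pattern as llama_model_is_recurrent (the actual implementation in the PR may differ):

```cpp
// Hedged sketch: mirrors the shape of llama_model_is_recurrent in
// src/llama.cpp; the real implementation may differ in details.
bool llama_model_is_hybrid(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_BAMBA: // architecture enum added by this PR
            return true;
        default:
            return false;
    }
}
```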

llama.cpp project architecture

  • llama_hparams:
    • Add recurrent_layer_arr to hold a per-layer indicator of recurrence (see the first sketch after this list)
    • Add bool recurrent_layer(uint32_t il) const to index into recurrent_layer_arr
    • Update n_embd_k_s and n_embd_v_s to be per-layer and return 0 for non-recurrent layers
  • llama_context:
    • Add struct llama_kv_cache kv_hybrid as a secondary KV cache for hybrid models (see the second sketch after this list). This is the biggest structural change; it allows the recurrent and attention layers to operate independently in the cache. For non-hybrid models, kv_hybrid is never initialized.
  • llama_kv_cache_init:
    • Add the recurrent argument. This allows recurrent models to initialize a non-recurrent cache (i.e. the attention cache for hybrid models)
    • Determine the size of the cache tensors on a per-layer basis so that layers which are managed by the other cache in a hybrid model have zero size (i.e. in the kv_self, all recurrent layers will have zero size, and in kv_hybrid, all attention layers will have zero size)
  • llm_load_hparams:
    • Automatically fill recurrent_layer_arr based on llama_model_is_recurrent(&model). For hybrid models, it falls to the individual model architecture construction to populate the per-layer entries correctly.
  • llm_build_mamba2:
    • Allow kv_hybrid as the cache that gets initialized instead of kv_self. This is determined based on a new hybrid flag passed by the build_<arch> implementation.
  • llm_build_context:
    • Add replicas of all KV-related members to support the second kv_hybrid cache (kv_hybrid, n_kv_hybrid, kv_head_hybrid, rs_zero_hybrid)
    • Populate all of the above from lctx.kv_hybrid
  • llama_set_inputs:
    • Take the recurrent branch for hybrid models and use kv_self or kv_hybrid as appropriate (kv_hybrid is the recurrent cache for hybrid models)
  • llama_decode_internal:
    • Use simple_split for hybrid models as well, even if kv_self.recurrent is false
    • Add a second llama_kv_slot_restorer for kv_hybrid and perform the save/restore operations if (and only if) the model is hybrid
    • Update kv_hybrid ring buffer when needed
    • Defrag the kv_hybrid cache for hybrid models
  • llama_new_context_with_model:
    • Manage kv_size and kv_size_hybrid independently
    • Init kv_self as recurrent IFF the model is recurrent but not hybrid
    • Init the kv_hybrid cache for hybrid models
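
To make the per-layer bookkeeping concrete, here is the first of two illustrative sketches (not the exact PR code): the recurrent_layer_arr indicator and how it lets the per-layer state sizes go to zero for layers owned by the other cache. Field names follow the description above and the SSM size formulas follow the existing mamba hparams, so details may differ from the actual implementation.

```cpp
#include <array>
#include <cstdint>

// Illustrative sketch only: the real hparams struct has many more fields.
struct hparams_sketch {
    uint32_t n_layer     = 0;
    uint32_t ssm_d_conv  = 0; // SSM convolution kernel size
    uint32_t ssm_d_inner = 0; // SSM inner dimension
    uint32_t ssm_d_state = 0; // SSM state dimension

    // true -> layer il is recurrent (mamba2), false -> attention
    std::array<bool, 512> recurrent_layer_arr = {};

    bool recurrent_layer(uint32_t il) const {
        return recurrent_layer_arr[il];
    }

    // Per-layer recurrent-state sizes; attention layers report 0, so the
    // recurrent cache allocates nothing for them (the attention cache
    // sizes its K/V tensors from the usual GQA head counts instead).
    uint32_t n_embd_k_s(uint32_t il) const {
        return recurrent_layer(il) ? (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner : 0;
    }
    uint32_t n_embd_v_s(uint32_t il) const {
        return recurrent_layer(il) ? ssm_d_state * ssm_d_inner : 0;
    }
};
```

The second sketch shows the dual-cache arrangement: kv_self stays the attention cache in a hybrid model, while kv_hybrid holds the recurrent state. The stub types below are only for illustration; the real llama_context and llama_kv_cache are far richer.

```cpp
#include <cstdint>

struct kv_cache_sketch {
    bool     recurrent = false; // true -> stores SSM states instead of K/V
    uint32_t size      = 0;     // number of cells
};

struct context_sketch {
    bool model_is_hybrid = false;

    kv_cache_sketch kv_self;   // attention cache; recurrent only for pure-SSM models
    kv_cache_sketch kv_hybrid; // recurrent cache; initialized only for hybrid models

    // Recurrent code paths (llama_set_inputs, llm_build_mamba2, ...) pick
    // the right cache depending on the model type.
    kv_cache_sketch & recurrent_cache() {
        return model_is_hybrid ? kv_hybrid : kv_self;
    }
};
```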

Model Architecture

  • Add architecture enum and layer set as normal
  • Add hparam LLM_KV_ATTENTION_LAYER_INDICES ("%s.attention.layer_indices") to hold an explicit per-layer list of which layers use attention (a rough illustration follows this list)
    • NOTE: Some hybrid models use a period and offset rather than an explicit list, but the list is the most flexible, so I opted to use only the list at the llama.cpp layer and require conversion from period/offset to a list in the conversion script
  • Add hparam LLM_KV_SSM_HEAD_DIM ("%s.ssm.head_dim") to set the head_dim from config rather than deducing it from d_inner / n_head
  • Add model enum entry in llm_load_tensors
  • Add build_bamba to construct the graph
  • Mark LLM_ARCH_BAMBA as hybrid in llama_model_is_hybrid
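
As a rough illustration of how the explicit attention-layer list could drive the per-layer flags at hparam-load time (the helper name and signature below are hypothetical; the actual loader code differs):

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Hypothetical helper: mark every layer recurrent by default and flip the
// layers listed under %s.attention.layer_indices back to attention.
static void set_recurrent_layers(
        std::array<bool, 512> &       recurrent_layer_arr,
        uint32_t                      n_layer,
        const std::vector<uint32_t> & attn_layer_indices) {
    for (uint32_t il = 0; il < n_layer; ++il) {
        recurrent_layer_arr[il] = true; // default: mamba2 (recurrent) layer
    }
    for (uint32_t il : attn_layer_indices) {
        if (il < n_layer) {
            recurrent_layer_arr[il] = false; // explicit attention layer
        }
    }
}
```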

Conversion

  • Add Keys.HybridMamba section in constants.py to hold hybrid model parameters
    • NOTE: I'm torn on whether this should be Keys.HybridMamba, Keys.HybridRecurrent, or just Keys.Hybrid
  • Add new hparams with plumbing for ssm_head_dim and attn_layer_indices in constants.py and gguf_writer.py
  • Update tensor_mapping.py names for Bamba layer names
  • Add class BambaModel in convert_hf_to_gguf.py and base it on Mamba2Model

Open Questions

  • What are the plans for the current mamba2 PR?
  • Is there a better approach to handling the hybrid KV caching? I tried going the route of adding additional pointers into the single kv_self, but kept getting hung up on kv_self.size being different for the two types of caching.
  • Are there other places where the KV cache is used that should be updated to also support kv_hybrid?

compilade and others added 30 commits August 21, 2024 18:00
There are likely still some missing hparams, but the tensor mapping should
be correct

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@ggerganov
Member

ggerganov commented Dec 13, 2024

What are the plans for the current mamba2 PR?

I think src/llama.cpp desperately needs some refactoring before we continue to introduce major changes to it. It's time to split it into multiple files and refactor the KV cache implementation to support different modes and be able to add tests. I have started several times to do that, but keep getting side-tracked by some other things.

After finalizing the TTS arch, I will try to finally do this refactoring. Or at the very least - the "split into separate source files" part.

I think mamba2 support should be added after this happens.

Is there a better approach to handling the hybrid KV caching? I tried going the route of adding additional pointers into the single kv_self, but kept getting hung up on kv_self.size being different for the two types of caching.

Are there other places where the KV cache is used that should be updated to also support kv_hybrid?

It's hard to answer. We are hacking the new KV modes into the original design and things are sometimes hard to fit. That is why we need to reimplement this in order to allow different implementations for different use cases.

@gabe-l-hart
Contributor Author

@ggerganov Thanks for the feedback!

It's time to split it into multiple files and refactor the KV cache implementation to support different modes and be able to add tests.

This is music to my ears! I honestly thought about trying to do this myself to help me decompose the problem of adding this hybrid support (knowing it would be throw-away, but a very good learning experience). I know I'm just dipping my toes in at this point, but if there's any help I can offer here, please let me know.

For the time being, unless you'd prefer otherwise, I'll keep this PR open in Draft as a guide to interested parties trying out the Bamba architecture. Once the big refactor moves forward, I'll look into refactoring this work on top of it.

gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request May 2, 2025
This is borrowed and adapted from the original implementation
ggml-org#10810

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart
Contributor Author

Closing in favor of #13550

CISC added a commit that referenced this pull request Jul 11, 2025
* wip: llama : separate recurrent states from the KV cache

This will be necessary to support Jamba
(and other recurrent models mixed with Attention).

Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.

* llama : use std::find for seq_nodes in llama_rs_cache

* llama : state checkpoints for recurrent models

* llama : correctly handle more edge cases for the rs cache

* llama : rename many llama_kv_cache_* functions

* llama : remove useless return value for some llama_cache_* functions

* llama : rethink recurrent state cell counts

* llama : begin work on support for variable GQA

This will also be useful for Jamba if we consider the Mamba layers
to have 0 KV heads.

* llama : gracefully fail when not finding hybrid slot

* llama : support Jamba

* llama : fix BERT inference without KV cache

* convert-hf : check for unprocessed Jamba experts

* convert-hf : support Mini-Jamba conversion

* llama : fix Jamba quantization sanity checks

* llama : sequence-length-aware batch splitting

* llama : use equal-sequence-length sub-batches for recurrent models

* ggml : simplify SSM-related operators

* llama : make recurrent state slot allocation contiguous

* llama : adapt internal uses of batches to llama_ubatch

* llama : fix batch split output count for embeddings

* llama : minimize swaps when reordering logits

This reduces overhead when running hellaswag
on thousands of sequences with very small 100k-parameter Mamba models.

* llama : fix edge case finding batch seq_id of split recurrent cell

This otherwise was a problem when running the HellaSwag benchmark
with small batch sizes, making it crash.

* llama : avoid copies for simple batch splits

* llama : use im2col and mul_mat to perform convolution for Mamba

This removes the need for ggml_ssm_conv!!!
But performance seems slightly worse on my system,
especially for prompt processing.
Maybe ggml_mul_mat isn't optimized for small row sizes?
More performance testing is necessary until GGML_OP_SSM_CONV is removed.

* ggml : make ggml_ssm_scan not modify its source tensors

* llama : fix shared recurrent tail cell count for small ubatch sizes

Otherwise it was impossible to run the 'parallel' example with '-ub 1'
with a Mamba or Jamba model.

* llama : fix .base() compilation error on Windows

* llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL

* ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors

The implementation already supported it,
and this makes Mamba's conv step slightly faster.

* llama : rename llama_cache to llama_past

This can be changed back later if the name change is wrong.
I was renaming the functions anyway to generalize kv-cache-related
functions to hybrid and recurrent model architectures.
I think llama_past is a better name than llama_cache for a combined
kv cache and recurrent state cache, because the states it contains
pretty much always come before the newly-added ones for any particular
sequence. Also 'llama_past_clear' sounds more obvious in what it does
than 'llama_kv_cache_clear'. The future is what the models generate.
(For embeddings, the kv cache isn't really used anyway)

Still, I'm open to better suggestions.

* examples : replace llama_kv_cache_seq_* with llama_past_seq_*

* mamba : fix non-contiguous usage of ggml_silu

* llama : initial Mamba-2 support

* ggml : SIMD ggml_ssm_scan for Mamba-2

* ggml : improve ggml_mul speed when masking recurrent states

* llama : support running Mamba-Codestral-7B-v0.1

* llama : fix Mamba-2 conv state saving

* ggml : make the ggml_mul fast broadcast path more consistently formatted

* llama : remove unused variable

* llama : add missing break

* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present

The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires
workarounds to work correctly.

* llama : session saving and reloading for hybrid models

* convert_hf : fix Jamba conversion

* llama : fix mixed signedness comparison

* llama : use unused n_embd_k_gqa in k_shift

This also slightly reduces the diff from the master branch

* llama : begin renaming llama_past back to llama_kv_cache

* llama : avoid redundant state copy for Mamba 1 and 2

* metal : attempt to adapt SSM_SCAN for Mamba-2

* metal : fix SSM_SCAN pipeline scope

* metal : use log and exp instead of log1pf and expf in SSM_SCAN

* metal : remove unused arguments for SSM_SCAN

The max index is 31, so trimming the arguments is necessary.

* metal : add back n_seqs to SSM_SCAN args

Whoops, this is needed for the offset in the concatenated output.

* metal : fix SSM_SCAN state head offset

* metal : fix wrong number of tokens per sequence in SSM_SCAN

* ggml : remove unused fast broadcast path in GGML_MUL

This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.

* ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models
to avoid some reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks

* convert : fix flake8 lint

* llama : remove implicit recurrent state rollbacks

* llama : partially apply clang-format style

* metal : fix confusion between ; and ,

* metal : add missing args for nb references in ssm_scan_f32_group

* metal : single-user mamba2 inference works

* kv-cache : remove const_cast when setting inputs for s_copy

And also fix multi-user inference for recurrent models
by using cell_id instead of i as the kv cell index
when populating s_copy.

* convert : avoid AutoConfig for Mamba and Mamba2 hparams

* kv-cache : allow context shift for recurrent models

* graph : fix recurrent state copies when avoiding copies

Works, but using lambda functions might not be that clean.

* ggml : fix mamba2 ssm scan when compiled with SVE

* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches

* cuda : implement ssm scan for Mamba2

There is still room for improvement, but it works!

* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2

* feat: Add conversion for Bamba models

This is borrowed and adapted from the original implementation
#10810

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add Granite 4 conversion

This is a manual copy from my draft branch
https://github.com/gabe-l-hart/llama.cpp/blob/GraniteFourDraft/convert_hf_to_gguf.py#L5076

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Plumb bamba through llama-arch

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add bamba to llama_arch_is_hybrid_recurrent

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add optional mamba ssm_in bias tensor

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add template specialization for get_arr to load a vector<uint32_t> for layer index arr in hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Use an explicit bool to determine mamba vs mamba2

This allows other architectures like bamba and granitemoehybrid to use
mamba2 without a growing architecture `if` statement inside the mamba
implementation.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Isolate mamba(2) and granite attention layer building in static methods

This will allow these layer-builder methods to be used from other build
structs without complex inheritance.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use per-layer sizes in granite build_attention_layer

Also no need to pass in kv cache since it's already in the inp_attn

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: First (broken) pass at end-to-end Bamba implementation

It generates (garbage) tokens! Still lots of debugging to do.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Only do Granite multipliers if set

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Pull granite ffn portion into a static function and reuse in hybrid

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(py): Allow gguf duplicate keys if they match by value and type

This is helpful for hybrid models that want to do gguf param setting by
calling multiple parent classes without needing to make those parent
classes try/except on every attempt to set a gguf value.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor(py): Simplify granitemoehybrid conversion to use parents better

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add GRANITE_MOE_HYBRID through llama-arch

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Support GRANITE_MOE_HYBRID in llama-model

This re-uses the Bamba code paths heavily and simply adds the missing parts
for loading MoE and the shared expert.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: Fix flake8 errors

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix recurrent cache get after rebase

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix hybrid granite implementation for signature changes in build_mamba*_layer

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Refactor relationship between non-hybrid classes and hybrid impl to use mixins

The challenge here is to give both the non-hybrid classes (llm_build_mamba
and llm_build_granite) AND the hybrid class (llm_build_hybrid_mamba) access
to the same intermediate "base class" functionality (build_mamba*_layer,
build_granite_attention_layer) without running into trouble with diamond
inheritance of llm_graph_context. Due to the non-trivial initialization
that happens in llm_graph_context, diamond inheritance results in multiple
initializations of the common base which cause problems around the unique
ptrs. I wanted to get away from `self->` everywhere, but this is still a
bit cleaner than making those methods static I think.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Implement the full copy-paste version to duplicate the layer builders

This follows the pattern where the type of input is pinned to the type of
memory and that is used to dispatch to the correct version of `build_rs` /
`build_attn`. There's a lot of code duplication that can hopefully be
pulled into common functions in the graph later.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Rename llm_build_hybrid_mamba -> llm_build_granite_hybrid

I've gone back and forth a lot about how/if to try to implement reuse of the
"child model" layer types for hybrid models. At the end of the day, I think
hybrid models are their own beast and even if their layers are inspired by
other models, they should maintain control of their own layer building (in
other words, the copy-paste method). Given that, the name should reflect
that this is not a generic hybrid model builder, but rather a granite-
specific hybrid model builder that can do MoE (granite 4) or dense (bamba).

As part of this, I also cleaned up dangling comments from previous attempts
at using static methods for reusability.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* mamba : fix mismatched new and delete size for llm_build_mamba

Subclasses of llm_graph_context cannot have extra fields,
because the called destructor is not the one from the subclass.
This otherwise would cause problems when running Mamba-(1|2) inference
when compiled -DGGML_SANITIZE_ADDRESS=ON

* memory : correctly handle failure in apply()

ggml-ci

* style: Remove TODO for adding first hybrid models to the switch

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix bad merge in tensor_mapping.py w/ SSM_NORM

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix bad merge resolution with variable renames/moves in llm_build_mamba

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* docs: Fix comment about duplicate key check

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Conform to standard way of initializing inp_out_ids

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* convert : fix jamba conv1d shape squeezing

* fix: Fix input initialization in granite_hybrid after removal of hybrid inputs

Branch: GraniteFourWithJamba

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use llm_graph_context_mamba in llm_build_granite_hybrid

Branch: GraniteFourWithJamba

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Refactor mamba2/granite/jamba/granite_hybrid relationships as mixins

The key is for the mixin classes (llm_graph_context_mamba,
llm_graph_context_granite) to use virtual inheritance from
llm_graph_context. This allows the common members to exist only once in the
class hierarchy. The downside is that llm_graph_context will be
re-initialized once for each parent (ie 2x for single mixin, 3x for two
mixins, etc...).

Branch: GraniteFourWithJamba

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* graph : add back hybrid memory graph input

But this time it contains the sub-cache graph inputs.
This *should* make it easier to handle updating the inputs
when caching the graph (eventually).

* model : add Jamba to Mamba-specific hparams printing

* fix: Fix input setup after upstream merge

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* jamba : remove redundant nullptr initializations

* model : remove unnecessary prefix for tensor loading constants

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* model : use ggml_swiglu_split for Mamba

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* feat: Add support for dense FFN in GraniteMoeHybrid

This was already partially supported via reusing the granite ffn builder,
and there may be models that leverage this architecture going forward. The
naming is a bit odd, but in the transformers version, it reuses the same
model class and simply has zero regular experts and a single shared expert
(which is the same as a single dense FFN).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add support for dense FFN tensor names on c++ side

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use child inputs for Falcon H1 after merge resolution

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove unnecessary prefix on tensor constants

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* model : make falcon-h1 use shared mamba2 layer builder

* memory : avoid referring to KV in recurrent cache logs

* fix: Revert order changes for Falcon H1 to stay consistent with upstream

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* gguf-py : avoid adding duplicate tensor mappings for Jamba

Some of the tensor names are common with Llama4

* refactor: Collapse Bamba and GraniteMoeHybrid into GraniteHybrid

The only key difference is the use of rope which is now set via
rope_finetuned in the hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Remove use of diamond inheritance

Per PR discussion, it's simpler to keep this with basic inheritance and not
introduce the complexity of virtual inheritance and multiple inheritance

#13550 (comment)

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Log mamba params for Granite Hybrid

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove unused ssm_in_b

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Remove ATTENTION_LAYER_INDICES hparam in favor of n_head_kv

This matches how recurrent vs attention heads are identified for Jamba

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove unused template expansion for get_arr

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Review cleanup in convert_hf_to_gguf

The gist is to be explicit about which base class is being used with the
multiple inheritance setup

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Undo hidden warnings about duplicate identical keys in add_key_value

After further discussion, this encourages sloppy overwriting in the model
converters

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: If not using ROPE, context is "infinite"

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* doc: Add a comment outlining expected duplicate key warnings

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove unnecessary duplicate keys in converter

Co-authored-by: Francis Couture-Harpin <git@compilade.net>

(thanks for the sharp eyes and patience!)

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Francis Couture-Harpin <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Labels: Apple Metal, ggml, python, testing