Skip to content

expansion adaper to allow language model for llava-next#486

Merged
rzbhatti merged 1 commit into
mainfrom
multi-model-serialization-expansion
Nov 12, 2025
Merged

expansion adaper to allow language model for llava-next#486
rzbhatti merged 1 commit into
mainfrom
multi-model-serialization-expansion

Conversation

@rzbhatti

@rzbhatti rzbhatti commented Nov 7, 2025

Copy link
Copy Markdown
Contributor

This PR fixes the issue #485
For multi-model (like Granite Vision model), the expansion adapter now supports expansion applied to the language model component.

Example call:

serialization.extend_adapter("llava_next", "hf", ["weight_expansion_for_mismatched_head_dim"])

config_dict = {}
config_dict['head_dim'] = 128

model = get_model(
    args.architecture,
    args.variant,
    model_path=args.model_path,
    device_type="cpu" if is_aiu_backend else args.device_type,
    data_type=default_dtype,
    source=args.model_source,
    distributed_strategy=distr_param,
    group=dist.group.WORLD,
    linear_config=linear_config,
    fused_weights=fused_weights,
    override_hf_pretrained_config=True,
    text_config=config_dict
)

Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
@rzbhatti rzbhatti requested a review from ani300 November 7, 2025 00:34
@Jordan-Murray22

Jordan-Murray22 commented Nov 7, 2025

Copy link
Copy Markdown

The above code worked for me when combined with the config.py changes in PR #484

Output:

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
NOTICE: Adjusting torch._dynamo.config.cache_size_limit from 8 to 40 to accomodate prompt size of 0 and decode tokens of 20
Fetching 10 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 131482.88it/s]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.6.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.6.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.6.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.6.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.16.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.16.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.16.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.16.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.29.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.29.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.29.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.29.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.24.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.24.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.24.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.24.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.27.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.27.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.27.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.27.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.23.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.23.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.23.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.23.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.25.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.25.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.25.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.25.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.38.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.38.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.38.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.38.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.28.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.28.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.28.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.28.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.39.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.39.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.39.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.39.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.12.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.12.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.12.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.12.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.14.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.14.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.14.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.14.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.18.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.18.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.18.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.18.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.10.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.10.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.10.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.10.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.36.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.36.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.36.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.36.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.32.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.32.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.32.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.32.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.34.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.34.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.34.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.34.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.13.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.13.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.13.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.13.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.17.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.17.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.17.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.17.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.37.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.37.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.37.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.37.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.26.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.26.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.26.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.26.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.2.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.2.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.2.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.2.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.0.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.0.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.0.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.0.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.21.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.21.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.21.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.21.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.7.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.7.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.7.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.7.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.9.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.9.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.9.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.9.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.35.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.35.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.35.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.35.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.33.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.33.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.33.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.33.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.22.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.22.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.22.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.22.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.1.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.1.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.1.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.1.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.4.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.4.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.4.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.4.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.5.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.5.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.5.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.5.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.19.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.19.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.19.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.19.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.15.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.15.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.15.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.15.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.3.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.3.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.3.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.3.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.30.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.30.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.30.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.30.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.20.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.20.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.20.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.20.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.8.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.8.attn.in_pr [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.8.attn.in_pr [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.8.attn.dense [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.11.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.11.attn.dens [2048, 2048] => [2048, 4096]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.11.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.11.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.31.attn.in_p [2048, 2048] => [4096, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.31.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.31.attn.in_p [512, 2048]  => [1024, 2048]
WARNING:fms.utils.serialization:expanding weights of base_model.layers.31.attn.dens [2048, 2048] => [2048, 4096]
<|system|>
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
<|user|>

What animal is shown in this image?
<|assistant|>

) -> Mapping[str, Any]:
new_sd = dict(input_sd)

# For multi model this expansion will be applicable only to the language_model

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this expansion only applicable to a language model?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can work for other models not just language. @rzbhatti can you confirm? We might need to update this comment.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need it only for the language model in the multimodal architecture. If we do not restrict it to a language model, the adapter starts expanding the attention layers of the vision model, too.

@kaoutar55 kaoutar55 self-requested a review November 12, 2025 21:53
@rzbhatti rzbhatti merged commit b8154bf into main Nov 12, 2025
4 checks passed
@rzbhatti rzbhatti deleted the multi-model-serialization-expansion branch November 12, 2025 22:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add support for multi-model serialization expansion adapter for llava-next

5 participants