Refactor HF Config -> FMS Model Kwarg Building #494
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
| hf_arch_name, *fms_info | ||
| ) | ||
| def register_model_arch_info(self, hf_arch_name, fms_arch_name, param_builder): |
There was a problem hiding this comment.
add docstring for this method as well as type hints
There was a problem hiding this comment.
Added doc strings and type hints for everything 🙂
| class ModelConfigRegistry: | ||
| """Wrapper class that handles converting hf config -> FMS kwargs.""" | ||
| def __init__(self, registry_map=None): |
There was a problem hiding this comment.
add type hint for registry_map. In the docs include what the mapping for registry_map is
| class ModelConfigRegistry: | ||
| """Wrapper class that handles converting hf config -> FMS kwargs.""" | ||
| def __init__(self, registry_map=None): |
There was a problem hiding this comment.
why is the registry_map optional here?
There was a problem hiding this comment.
Consider making explicit why registry_map is optional in the ModelConfigRegistry constructor. If we don’t expect external callers to omit it, mark it as non-optional and remove the guard.
There was a problem hiding this comment.
Initially, I made it optional because I wanted the registry to support out-of-tree use cases in case we end up needing it one day, and because the map was originally saved directly instead of being split apart (i.e., it defaulted to None to avoid using a mutable reference type as the default value).
I agree though - I'll make it not optional and remove the guard, since ultimately someone could just pass an empty dict if they want to do the same thing 🙂
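For reference, a minimal sketch of what the non-optional constructor could look like once the guard is removed. The type aliases and the assumed (fms_arch_name, param_builder) tuple layout of registry_map are guesses based on the diff context above, not the merged code:

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical aliases; the PR's actual type names are not shown in this thread.
ParamBuilder = Callable[[Any], Dict[str, Any]]
RegistryMap = Dict[str, Tuple[str, ParamBuilder]]


class ModelConfigRegistry:
    """Wrapper class that handles converting hf config -> FMS kwargs."""

    def __init__(self, registry_map: RegistryMap) -> None:
        """Build the registry from a required mapping.

        Args:
            registry_map: maps an HF architecture name to a tuple of
                (FMS architecture name, param builder). Assumed layout.
        """
        self.model_param_builders: Dict[str, ParamBuilder] = {}
        self.model_arch_mappings: Dict[str, str] = {}
        for hf_arch_name, fms_info in registry_map.items():
            self.register_model_arch_info(hf_arch_name, *fms_info)

    def register_model_arch_info(
        self, hf_arch_name: str, fms_arch_name: str, param_builder: ParamBuilder
    ) -> None:
        """Register the FMS architecture name and param builder for an HF arch."""
        self.model_param_builders[hf_arch_name] = param_builder
        self.model_arch_mappings[hf_arch_name] = fms_arch_name


# Out-of-tree callers can still start empty by passing an empty dict.
registry = ModelConfigRegistry({})
registry.register_model_arch_info(
    "ExampleForCausalLM", "example", lambda config: {"nlayers": 2}
)
```

Passing an empty dict keeps the out-of-tree use case available without needing a None default.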
| self.model_param_builders[hf_arch_name] = param_builder | ||
| self.model_arch_mappings[hf_arch_name] = fms_arch_name | ||
| def map_hf_to_fms_arch(self, architecture): |
There was a problem hiding this comment.
add docstring for this method as well as type hints
| return fms_arch | ||
| raise KeyError(f"HF architecture {architecture} is unsupported! Registered architectures: {list(self.model_arch_mappings.keys())}") | ||
| def map_hf_arch_to_fms_params(self, architecture, config): |
There was a problem hiding this comment.
add type hints for this, this will help to differentiate the hf config from fms
| @@ -0,0 +1,22 @@ | |||
| from functools import partial | |||
| from fms.models.hf.config_utils.config_registry import ModelConfigRegistry | |||
| from fms.models.hf.config_utils.param_builders import * | |||
There was a problem hiding this comment.
Please replace the star import with explicit imports (e.g., import the builder functions individually or import the module as a namespace).
This will improve readability, auto-complete, and static analysis across the stack.
There was a problem hiding this comment.
Sure - done 🙂
| if linear_config: | ||
| config_params["linear_config"] = linear_config | ||
| config_params = _FMS_MODEL_CONFIG_REGISTRY.hf_config_to_fms_config_params( |
There was a problem hiding this comment.
Since _map_model_config is fully removed and replaced with registry-based param builders, can we confirm backward compatibility across all previously supported models (Granite, Mistral, Mixtral, Bamba, Roberta, LlavaNext, MPNet, etc.)?
A short regression test or note confirming parity would help ensure we don't have a regression.
There was a problem hiding this comment.
Yup! I have added a test for regression testing here, just in a separate PR to make the review easier: #493.
This PR has one example for every model architecture and validates that infer_model_configuration produces the same config param kwargs for each model arch, and I used it to validate the correctness of this refactor. We can also add new model architectures to this test as we go to prevent future regressions in the model param 😄
| if architecture in self.model_arch_mappings: | ||
| fms_arch = self.model_arch_mappings[architecture] | ||
| return fms_arch | ||
| raise KeyError(f"HF architecture {architecture} is unsupported! Registered architectures: {list(self.model_arch_mappings.keys())}") |
There was a problem hiding this comment.
Before raising the KeyError, we might want to emit a warning log (logger.warning).
This makes debugging easier when users load partially-supported or custom HF configs.
There was a problem hiding this comment.
Sounds reasonable to me - I changed the two lookups (for the arch and for the param builder) to log warnings, and moved the error so that it is raised after one or both warnings have been logged, in case those mappings somehow get misaligned
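A rough sketch of the warn-then-raise flow described above. The function name lookup_arch_info and the warning messages are hypothetical; only the KeyError text comes from the diff:

```python
import logging
from typing import Any, Callable, Dict, Tuple

logger = logging.getLogger(__name__)

# Hypothetical alias for a callable mapping an HF config to FMS kwargs.
ParamBuilder = Callable[[Any], Dict[str, Any]]


def lookup_arch_info(
    architecture: str,
    arch_mappings: Dict[str, str],
    param_builders: Dict[str, ParamBuilder],
) -> Tuple[str, ParamBuilder]:
    """Look up both mappings, warn on each miss, then raise once."""
    fms_arch = arch_mappings.get(architecture)
    if fms_arch is None:
        logger.warning(
            "HF architecture %s has no registered FMS architecture", architecture
        )

    param_builder = param_builders.get(architecture)
    if param_builder is None:
        logger.warning(
            "HF architecture %s has no registered param builder", architecture
        )

    # Raise only after both lookups, so a misaligned registry logs both misses.
    if fms_arch is None or param_builder is None:
        raise KeyError(
            f"HF architecture {architecture} is unsupported! "
            f"Registered architectures: {list(arch_mappings.keys())}"
        )
    return fms_arch, param_builder
```

Deferring the raise means that if only one of the two mappings is missing, the log shows exactly which one.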
| "tie_heads": config.tie_word_embeddings, | ||
| } | ||
| # Should not have overlap | ||
| assert not any(common_params) in config_params |
There was a problem hiding this comment.
I think this is an issue: any(common_params) returns a bool, so this check does not validate key collisions.
Should be:
assert not any(k in config_params for k in common_params).
There was a problem hiding this comment.
🤦 didn't realize that I left that in there, and nice catch, thanks! Rewrote the validation to only raise if there are intersecting keys whose values are conflicting, so it'll raise something like ValueError: Model param builder uses common params, but has conflicting values for key(s) ['src_vocab_size'] instead of a random assert
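To make the problem concrete, here is a small sketch of why the original assert never fires, plus one possible shape for the conflict check described above. The helper name check_common_param_conflicts and the sample values are made up for illustration; only the error message text comes from this thread:

```python
from typing import Any, Dict

common_params = {"src_vocab_size": 32000, "tie_heads": False}
config_params = {"src_vocab_size": 50000, "nlayers": 4}

# `in` binds tighter than `not`, so this is `not (any(...) in config_params)`.
# any(common_params) is True, and True is not a key of config_params, so the
# original check passes even though "src_vocab_size" overlaps.
buggy_check_passes = not any(common_params) in config_params


def check_common_param_conflicts(
    config_params: Dict[str, Any], common_params: Dict[str, Any]
) -> None:
    """Raise only if a shared key has different values in the two dicts."""
    shared_keys = common_params.keys() & config_params.keys()
    conflicting = sorted(
        key for key in shared_keys if common_params[key] != config_params[key]
    )
    if conflicting:
        raise ValueError(
            "Model param builder uses common params, but has conflicting "
            f"values for key(s) {conflicting}"
        )
```

With the sample dicts above, the buggy assert passes silently, while the helper raises a ValueError naming src_vocab_size.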
| @@ -0,0 +1,22 @@ | |||
| from functools import partial | |||
There was a problem hiding this comment.
Could we add a top-level docstring explaining how HF → FMS config mapping works via the registry?
This is now a central extension point, so documenting “how to register a new model” would be very helpful.
There was a problem hiding this comment.
Definitely! Added an in depth explanation for how to map the params here, and can also pull it out into more findable docs if we end up wanting steps for how to implement and test new models in general later on 😄
| # If we received the outer siglip model config, pass only the vision | ||
| # encoder config, because we do not care about the text encoder here. | ||
| if hasattr(config, "vision_config"): | ||
| config = config.vision_config |
There was a problem hiding this comment.
This mutates the meaning of config inside the function and may cause confusion or accidental override.
Better to have:
vision_cfg = config.vision_config.
There was a problem hiding this comment.
Sounds good to me 🙂 changed and added a bit of clarification in the docstring to make it more explicit that the wrapping config (i.e., siglipconfig / llava next config) should be passed, and not the already unwrapped siglipvisionconfig
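A minimal sketch of the non-shadowing version, assuming the builder always receives the wrapping config. The function name and the FMS kwarg names are hypothetical, and SimpleNamespace stands in for the real HF config objects:

```python
from types import SimpleNamespace
from typing import Any, Dict


def build_siglip_vision_params(config: Any) -> Dict[str, Any]:
    """Build FMS kwargs for the siglip vision encoder (illustrative only).

    Args:
        config: the wrapping HF config exposing `vision_config`, not the
            already unwrapped vision config.
    """
    # Bind the nested config to a new name instead of reassigning `config`.
    vision_cfg = config.vision_config
    return {
        # Hypothetical FMS kwarg names; the real ones are not shown here.
        "hidden_size": vision_cfg.hidden_size,
        "nlayers": vision_cfg.num_hidden_layers,
    }


# Stand-in for a wrapping HF config object.
wrapping_config = SimpleNamespace(
    vision_config=SimpleNamespace(hidden_size=1152, num_hidden_layers=27)
)
params = build_siglip_vision_params(wrapping_config)
```

Keeping config bound to the outer object means any later reference to it in the function still refers to what the caller passed in.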
| raise KeyError(f"HF architecture {architecture} is unsupported! Registered architectures: {list(self.model_arch_mappings.keys())}") | ||
| def hf_config_to_fms_config_params(self, config, model_path): | ||
| architecture = config.architectures[0] |
There was a problem hiding this comment.
HF configs sometimes include multiple architectures.
Should we validate or warn when len(config.architectures) > 1?
Relying on index 0 may not always be correct.
There was a problem hiding this comment.
Good point! We already do this currently (here) though, so the behavior is unchanged in this PR.
For now, I added a "The provided HF config supports multiple architectures; only the first, <ARCH_NAME>, will be used in building the FMS model's config params" warning to make the current behavior more clear!
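As a sketch, the added warning could look roughly like this. The helper name select_architecture is hypothetical, and SimpleNamespace stands in for an HF config:

```python
import logging
from types import SimpleNamespace
from typing import Any

logger = logging.getLogger(__name__)


def select_architecture(config: Any) -> str:
    """Return the first HF architecture, warning if more than one is listed."""
    architecture = config.architectures[0]
    if len(config.architectures) > 1:
        logger.warning(
            "The provided HF config supports multiple architectures; only the "
            "first, %s, will be used in building the FMS model's config params",
            architecture,
        )
    return architecture


# Stand-in config listing two architectures; the first one is used.
example_config = SimpleNamespace(architectures=["ArchA", "ArchB"])
chosen = select_architecture(example_config)
```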
Force-pushed from 340feee to 00f6db8
Force-pushed from 6dd8b86 to e7cf3c9
Force-pushed from e7cf3c9 to 360d7cc
Thanks for the quick reviews @JRosenkranz and @kaoutar55! I think things should be ready for another look when you have a moment.
…k#494)
* refactor hf utils to split kwarg building
* split fms kwarg builders -> separate utils
* refactor into global registry for model configs
* run formatting
* no wild imports in builder imports
* raise conflicting builder/common params values
* Make registry map required
* dont shadow config var name in siglip builder
* warn if we have multiple archs
* Add guidelines for mapping new architectures
* Add docstrings and type hints
* add config util types
* move keyerror further in registry mapping
* Add builder docstrings, pull out wrapped imports
* clarify siglip param builder
* fix mypy
* fix tests, make registry private
---------
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Breaks up the big HF -> FMS _map_model_config chain into a global registry with more well-patterned & reusable param builders. After this change, when adding a new model architecture, we should:
* […] build_<new_arch>_params callable, which takes as input the HF config, and returns the dict of FMS config_params.
* […] _FMS_MODEL_REGISTRY in fms/models/hf/config_utils/__init__.py.
* […] main, although probably should be consolidated into the internal testing sub package, or wrapped in scripts later on) generate the corresponding JSON dict for the params.

In some cases, some architectures may have one or two extra kwargs that need to be pulled out of the common config based on the task, e.g., classification for Bert / Roberta. Currently these cases are pretty limited, so for now, this is handled by setting a new flag (i.e., is_classify) with a default to the more common case, and wrapping it in a partial when mapping the less common case in the registry, which allows us to patch the kwarg while keeping the signature standardized. Hopefully this will make things easier to look at!

In the future, I think there are a couple of things that could be added to make this more useful as well:
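Going back to the is_classify / partial pattern described above, a minimal sketch of how it could look. The builder name, the architecture keys, the FMS kwarg names, and the choice of False as the default are all assumptions for illustration; SimpleNamespace stands in for an HF config:

```python
from functools import partial
from types import SimpleNamespace
from typing import Any, Dict


def build_roberta_params(config: Any, is_classify: bool = False) -> Dict[str, Any]:
    """Illustrative builder: one extra kwarg is added for classification."""
    # Hypothetical FMS kwarg names; the real ones are not shown in this PR page.
    config_params: Dict[str, Any] = {"emb_dim": config.hidden_size}
    if is_classify:
        config_params["num_classes"] = config.num_labels
    return config_params


# Registry entries map an HF arch name to (FMS arch name, param builder).
# The less common case wraps the same builder in a partial, so every
# registered callable still takes just the HF config.
registry_map = {
    "RobertaForMaskedLM": ("roberta", build_roberta_params),
    "RobertaForSequenceClassification": (
        "roberta_classification",
        partial(build_roberta_params, is_classify=True),
    ),
}

hf_config = SimpleNamespace(hidden_size=768, num_labels=3)
_, classify_builder = registry_map["RobertaForSequenceClassification"]
classify_params = classify_builder(hf_config)
```

Because the partial fixes is_classify up front, the registry can call every builder the same way, with only the HF config as an argument.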