From f11b31952618f1ffa2d5763be67e6492e3680e91 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 25 Apr 2026 07:11:31 +0000 Subject: [PATCH 1/6] Initial plan From cdd61881c4c091b34bce9578ba1bf38dc002c370 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 25 Apr 2026 07:18:54 +0000 Subject: [PATCH 2/6] feat: add architecture registration and fallback loading for new models Agent-Logs-Url: https://github.com/codewithdark-git/QuantLLM/sessions/274cc6b0-d42e-47b1-9673-1f6db346ecf2 Co-authored-by: codewithdark-git <144595403+codewithdark-git@users.noreply.github.com> --- README.md | 8 ++ docs/guide/loading-models.md | 36 ++++++ quantllm/__init__.py | 2 + quantllm/core/__init__.py | 3 +- quantllm/core/turbo_model.py | 172 +++++++++++++++++++++++++++- tests/test_architecture_fallback.py | 153 +++++++++++++++++++++++++ 6 files changed, 369 insertions(+), 5 deletions(-) create mode 100644 tests/test_architecture_fallback.py diff --git a/README.md b/README.md index 0555b85..6e9e9b9 100644 --- a/README.md +++ b/README.md @@ -348,6 +348,14 @@ pytest - 📚 Documentation - 🐛 Bug fixes +**Quick template for new architecture support:** +```python +from quantllm import register_architecture, turbo + +register_architecture("new-arch", base_model_type="llama") +model = turbo("org/new-arch-7b", base_model_fallback=True, trust_remote_code=True) +``` + --- ## 📜 License diff --git a/docs/guide/loading-models.md b/docs/guide/loading-models.md index 3bf1010..54398e6 100644 --- a/docs/guide/loading-models.md +++ b/docs/guide/loading-models.md @@ -74,6 +74,42 @@ model = turbo( ) ``` +### New Architecture Fallbacks (for very recent model releases) + +If `transformers` does not recognize a just-released architecture yet, register a fallback family: + +```python +from quantllm import turbo, register_architecture + +# Map new architecture/model_type to a compatible base family +register_architecture("newmodel", base_model_type="llama") + +model = turbo( + "new-model-org/NewModel-7B", + model_type_override="llama", # optional explicit override + base_model_fallback=True, # retry with resolved fallback config + trust_remote_code=True, +) +``` + +You can also load from config only (no checkpoint weights) while waiting for upstream support: + +```python +model = turbo( + "new-model-org/NewModel-7B", + from_config_only=True, + trust_remote_code=True, +) +``` + +#### Fast contribution template for new architectures + +1. Add a registration in your code or PR: + - `register_architecture("new-arch", base_model_type="llama")` +2. Validate loading with: + - `turbo("org/model", base_model_fallback=True, trust_remote_code=True)` +3. Add/extend a focused test in `tests/test_architecture_fallback.py`. 
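+
+A minimal sketch of such a test (hypothetical test name, mirroring the existing resolve/register tests in that file):
+
+```python
+from quantllm.core.turbo_model import TurboModel
+
+
+def test_new_arch_resolves_to_base_family(monkeypatch):
+    # Isolate the class-level registries so the test does not leak state.
+    monkeypatch.setattr(TurboModel, "_architecture_registry", {})
+    monkeypatch.setattr(TurboModel, "_model_class_registry", {})
+
+    TurboModel.register_architecture("new-arch", base_model_type="llama")
+
+    # The registered alias should resolve from the repo/model name alone.
+    assert TurboModel.resolve_model_type("org/new-arch-7b") == "llama"
+```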
+ ### Memory Options ```python diff --git a/quantllm/__init__.py b/quantllm/__init__.py index 6f2933b..da0b7fe 100644 --- a/quantllm/__init__.py +++ b/quantllm/__init__.py @@ -35,6 +35,7 @@ from .core import ( turbo, TurboModel, + register_architecture, SmartConfig, HardwareProfiler, ModelAnalyzer, @@ -117,6 +118,7 @@ def show_banner(force: bool = False): # Main API "turbo", "TurboModel", + "register_architecture", "SmartConfig", "HardwareProfiler", "ModelAnalyzer", diff --git a/quantllm/core/__init__.py b/quantllm/core/__init__.py index 5e64f1a..823ca59 100644 --- a/quantllm/core/__init__.py +++ b/quantllm/core/__init__.py @@ -8,7 +8,7 @@ from .hardware import HardwareProfiler from .smart_config import SmartConfig from .model_analyzer import ModelAnalyzer -from .turbo_model import TurboModel, turbo +from .turbo_model import TurboModel, turbo, register_architecture from .compilation import ( compile_model, compile_for_inference, @@ -51,6 +51,7 @@ "ModelAnalyzer", "TurboModel", "turbo", + "register_architecture", # Compilation "compile_model", "compile_for_inference", diff --git a/quantllm/core/turbo_model.py b/quantllm/core/turbo_model.py index 53ec668..79aef24 100644 --- a/quantllm/core/turbo_model.py +++ b/quantllm/core/turbo_model.py @@ -7,7 +7,7 @@ import os import shutil import tempfile -from typing import Optional, Dict, Any, Union, List +from typing import Optional, Dict, Any, Union, List, Type import torch import torch.nn as nn from transformers import ( @@ -32,6 +32,14 @@ "quantization": "Q4_K_M", "push_quantization": None, } +DEFAULT_ARCHITECTURE_FALLBACKS = { + "llama": "llama", + "mistral": "mistral", + "mixtral": "mistral", + "qwen": "qwen2", + "phi": "phi", + "gemma": "gemma", +} class TurboModel: @@ -58,6 +66,9 @@ class TurboModel: >>> model.export("gguf", "my_model.gguf") """ + _architecture_registry: Dict[str, str] = {} + _model_class_registry: Dict[str, Type[PreTrainedModel]] = {} + def __init__( self, model: PreTrainedModel, @@ -82,6 +93,120 @@ def __init__( self._lora_applied = False self.export_push_config = self._build_export_push_config(export_push_config) self.verbose = verbose + + @classmethod + def register_architecture( + cls, + architecture: str, + *, + base_model_type: Optional[str] = None, + model_class: Optional[Type[PreTrainedModel]] = None, + ) -> None: + """ + Register a new architecture alias and optional explicit model class. 
+ + Args: + architecture: Architecture or model type name to register + base_model_type: Base model family to fall back to (e.g., "llama") + model_class: Explicit model class with from_pretrained() + """ + normalized = architecture.lower().strip() + if not normalized: + raise ValueError("architecture must be a non-empty string") + + if base_model_type: + cls._architecture_registry[normalized] = base_model_type.lower().strip() + + if model_class is not None: + cls._model_class_registry[normalized] = model_class + + @classmethod + def resolve_model_type( + cls, + model_name: str, + *, + config_model_type: Optional[str] = None, + model_type_override: Optional[str] = None, + ) -> Optional[str]: + """Resolve model type using override, registry, and default family patterns.""" + if model_type_override: + return model_type_override.lower().strip() + + model_type = (config_model_type or "").lower().strip() + if model_type: + return cls._architecture_registry.get(model_type, model_type) + + name = model_name.lower() + for pattern, fallback in cls._architecture_registry.items(): + if pattern in name: + return fallback + + for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items(): + if pattern in name: + return fallback + + return None + + @classmethod + def _load_model_with_fallback( + cls, + model_name: str, + model_kwargs: Dict[str, Any], + *, + trust_remote_code: bool, + hf_config: Optional[Any], + model_type_override: Optional[str], + base_model_fallback: bool, + from_config_only: bool, + ) -> PreTrainedModel: + """Load model with architecture fallback and optional config-only mode.""" + resolved_model_type = cls.resolve_model_type( + model_name, + config_model_type=getattr(hf_config, "model_type", None), + model_type_override=model_type_override, + ) + + if hf_config is not None and resolved_model_type: + setattr(hf_config, "model_type", resolved_model_type) + + if from_config_only: + if hf_config is None: + raise ValueError( + "from_config_only=True requires a loadable config. " + "Try trust_remote_code=True or set model_type_override." + ) + return AutoModelForCausalLM.from_config( + hf_config, + trust_remote_code=trust_remote_code, + torch_dtype=model_kwargs.get("torch_dtype"), + ) + + try: + return AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + except Exception as primary_error: + if not base_model_fallback: + raise + + if hf_config is not None: + fallback_kwargs = dict(model_kwargs) + fallback_kwargs["config"] = hf_config + try: + return AutoModelForCausalLM.from_pretrained(model_name, **fallback_kwargs) + except Exception: + pass + + if resolved_model_type: + registered_cls = cls._model_class_registry.get(resolved_model_type) + if registered_cls is not None: + class_kwargs = dict(model_kwargs) + if hf_config is not None: + class_kwargs["config"] = hf_config + return registered_cls.from_pretrained(model_name, **class_kwargs) + + raise RuntimeError( + "Failed to load model with AutoModelForCausalLM and fallback resolution. " + "Try register_architecture(...), model_type_override='llama', or from_config_only=True." 
+ ) from primary_error @classmethod def from_pretrained( @@ -96,6 +221,9 @@ def from_pretrained( # Advanced options trust_remote_code: bool = True, quantize: bool = True, + model_type_override: Optional[str] = None, + base_model_fallback: bool = True, + from_config_only: bool = False, config_override: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, verbose: bool = True, @@ -117,6 +245,9 @@ def from_pretrained( dtype: Override dtype (default: bf16 if available, else fp16) trust_remote_code: Trust remote code in model quantize: Whether to quantize the model + model_type_override: Override detected model_type for very new architectures + base_model_fallback: Retry loading with resolved base model config on failure + from_config_only: Build model from config only (without loading weights) config_override: Dict to override any auto-detected settings config: Shared export/push config (format, quantization, push_format, etc.) verbose: Print loading progress @@ -196,10 +327,19 @@ def from_pretrained( "torch_dtype": smart_config.dtype, } + hf_config = None + # Check if model is already quantized to prevent conflicts try: from transformers import AutoConfig hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) + resolved_model_type = cls.resolve_model_type( + model_name, + config_model_type=getattr(hf_config, "model_type", None), + model_type_override=model_type_override, + ) + if resolved_model_type: + setattr(hf_config, "model_type", resolved_model_type) existing_quant = getattr(hf_config, "quantization_config", None) if existing_quant: @@ -225,7 +365,7 @@ def from_pretrained( pass # Ignore config loading errors, proceed with defaults # Apply quantization if requested - if quantize and smart_config.bits < 16: + if quantize and smart_config.bits < 16 and not from_config_only: model_kwargs.update(cls._get_quantization_kwargs(smart_config)) # Device map for memory management @@ -240,9 +380,14 @@ def from_pretrained( if verbose: task = p.add_task("Downloading & Loading model...", total=None) - model = AutoModelForCausalLM.from_pretrained( + model = cls._load_model_with_fallback( model_name, - **model_kwargs, + model_kwargs, + trust_remote_code=trust_remote_code, + hf_config=hf_config, + model_type_override=model_type_override, + base_model_fallback=base_model_fallback, + from_config_only=from_config_only, ) if verbose: @@ -1892,6 +2037,25 @@ def _replace_with_triton(self, module: nn.Module, bits: int) -> int: return count +def register_architecture( + architecture: str, + *, + base_model_type: Optional[str] = None, + model_class: Optional[Type[PreTrainedModel]] = None, +) -> None: + """ + Register a new architecture alias and optional explicit model class. 
+ + Example: + >>> register_architecture("my-new-model", base_model_type="llama") + """ + TurboModel.register_architecture( + architecture, + base_model_type=base_model_type, + model_class=model_class, + ) + + def turbo( model: str, *, diff --git a/tests/test_architecture_fallback.py b/tests/test_architecture_fallback.py new file mode 100644 index 0000000..ceeb5ca --- /dev/null +++ b/tests/test_architecture_fallback.py @@ -0,0 +1,153 @@ +from types import SimpleNamespace + +import transformers + +from quantllm.core.turbo_model import TurboModel +import quantllm.core.turbo_model as turbo_model_module + + +class _DummySmartConfig(SimpleNamespace): + def print_summary(self): + return None + + +def _make_smart_config(): + return _DummySmartConfig( + bits=16, + effective_loading_bits=16, + dtype="float16", + cpu_offload=False, + device="cpu", + gradient_checkpointing=False, + use_flash_attention=False, + compile_model=False, + ) + + +def _make_tokenizer(): + return SimpleNamespace(pad_token=None, eos_token="", eos_token_id=2) + + +def test_resolve_model_type_detects_common_patterns(): + assert TurboModel.resolve_model_type("meta-llama/Llama-3.2-3B") == "llama" + assert TurboModel.resolve_model_type("Qwen/Qwen3-8B") == "qwen2" + + +def test_register_architecture_maps_new_model_to_base_family(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + TurboModel.register_architecture("newmodel", base_model_type="llama") + + assert TurboModel.resolve_model_type("org/newmodel-7b") == "llama" + + +def test_registered_class_fallback_is_used(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: _make_smart_config(), + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="newmodel", + quantization_config=None, + ), + ) + + class _FakeAutoModel: + @staticmethod + def from_pretrained(*args, **kwargs): + raise ValueError("Unrecognized configuration class") + + @staticmethod + def from_config(*args, **kwargs): + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + class _RegisteredModel: + called = False + + @classmethod + def from_pretrained(cls, *args, **kwargs): + cls.called = True + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + TurboModel.register_architecture("newmodel", base_model_type="llama") + TurboModel.register_architecture("llama", model_class=_RegisteredModel) + + loaded = TurboModel.from_pretrained( + "org/newmodel-7b", + quantize=False, + verbose=False, + ) + + assert _RegisteredModel.called is True + assert loaded.model.config.model_type == "llama" + + +def test_from_pretrained_supports_from_config_only(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: _make_smart_config(), + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + 
monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="llama", + quantization_config=None, + ), + ) + + class _FakeAutoModel: + called_from_pretrained = False + called_from_config = False + + @classmethod + def from_pretrained(cls, *args, **kwargs): + cls.called_from_pretrained = True + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + @classmethod + def from_config(cls, *args, **kwargs): + cls.called_from_config = True + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + loaded = TurboModel.from_pretrained( + "org/llama-like-7b", + quantize=False, + verbose=False, + from_config_only=True, + ) + + assert _FakeAutoModel.called_from_pretrained is False + assert _FakeAutoModel.called_from_config is True + assert loaded.model.config.model_type == "llama" From 83b66ae01448d4ffb880c92c5cb25bbb14ab46b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 25 Apr 2026 07:25:13 +0000 Subject: [PATCH 3/6] docs/tests: polish architecture fallback behavior and guidance Agent-Logs-Url: https://github.com/codewithdark-git/QuantLLM/sessions/274cc6b0-d42e-47b1-9673-1f6db346ecf2 Co-authored-by: codewithdark-git <144595403+codewithdark-git@users.noreply.github.com> --- quantllm/core/turbo_model.py | 86 +++++++++++++++++++++-------- tests/test_architecture_fallback.py | 18 +++--- 2 files changed, 73 insertions(+), 31 deletions(-) diff --git a/quantllm/core/turbo_model.py b/quantllm/core/turbo_model.py index 79aef24..d3b324f 100644 --- a/quantllm/core/turbo_model.py +++ b/quantllm/core/turbo_model.py @@ -5,8 +5,11 @@ """ import os +import re import shutil import tempfile +import copy +from functools import lru_cache from typing import Optional, Dict, Any, Union, List, Type import torch import torch.nn as nn @@ -128,7 +131,12 @@ def resolve_model_type( config_model_type: Optional[str] = None, model_type_override: Optional[str] = None, ) -> Optional[str]: - """Resolve model type using override, registry, and default family patterns.""" + """ + Resolve model type using override, registry, and default family patterns. + + If config_model_type is provided but unregistered, the original config value + is returned unchanged. + """ if model_type_override: return model_type_override.lower().strip() @@ -138,15 +146,37 @@ def resolve_model_type( name = model_name.lower() for pattern, fallback in cls._architecture_registry.items(): - if pattern in name: + if cls._matches_model_name_pattern(name, pattern): return fallback for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items(): - if pattern in name: + if cls._matches_model_name_pattern(name, pattern): return fallback return None + @classmethod + def _matches_model_name_pattern(cls, model_name: str, pattern: str) -> bool: + """Return True when pattern appears as a token in model_name.""" + return cls._compiled_model_name_pattern(pattern).search(model_name) is not None + + @staticmethod + @lru_cache(maxsize=256) + def _compiled_model_name_pattern(pattern: str): + """Compile and cache token-boundary regex patterns for model-name matching.""" + escaped = re.escape(pattern) + # Match architecture tokens as standalone chunks split by separators. 
+ return re.compile(rf"(^|[^a-z0-9]){escaped}([^a-z0-9]|$)") + + @staticmethod + def _should_apply_quantization( + quantize: bool, + bits: int, + from_config_only: bool, + ) -> bool: + """Check whether quantization arguments should be added for loading.""" + return quantize and bits < 16 and not from_config_only + @classmethod def _load_model_with_fallback( cls, @@ -165,18 +195,22 @@ def _load_model_with_fallback( config_model_type=getattr(hf_config, "model_type", None), model_type_override=model_type_override, ) + resolved_config = hf_config if hf_config is not None and resolved_model_type: - setattr(hf_config, "model_type", resolved_model_type) + current_model_type = getattr(hf_config, "model_type", None) + if current_model_type != resolved_model_type: + resolved_config = copy.deepcopy(hf_config) + setattr(resolved_config, "model_type", resolved_model_type) if from_config_only: - if hf_config is None: + if resolved_config is None: raise ValueError( "from_config_only=True requires a loadable config. " "Try trust_remote_code=True or set model_type_override." ) return AutoModelForCausalLM.from_config( - hf_config, + resolved_config, trust_remote_code=trust_remote_code, torch_dtype=model_kwargs.get("torch_dtype"), ) @@ -186,27 +220,38 @@ def _load_model_with_fallback( except Exception as primary_error: if not base_model_fallback: raise + fallback_error = None - if hf_config is not None: + if resolved_config is not None: fallback_kwargs = dict(model_kwargs) - fallback_kwargs["config"] = hf_config + fallback_kwargs["config"] = resolved_config try: return AutoModelForCausalLM.from_pretrained(model_name, **fallback_kwargs) - except Exception: - pass + except Exception as fallback_config_error: + fallback_error = fallback_config_error if resolved_model_type: registered_cls = cls._model_class_registry.get(resolved_model_type) if registered_cls is not None: class_kwargs = dict(model_kwargs) - if hf_config is not None: - class_kwargs["config"] = hf_config - return registered_cls.from_pretrained(model_name, **class_kwargs) + if resolved_config is not None: + class_kwargs["config"] = resolved_config + try: + return registered_cls.from_pretrained(model_name, **class_kwargs) + except Exception as fallback_registered_error: + fallback_error = fallback_registered_error + + error_details = f" Last fallback error: {fallback_error}" if fallback_error else "" raise RuntimeError( - "Failed to load model with AutoModelForCausalLM and fallback resolution. " - "Try register_architecture(...), model_type_override='llama', or from_config_only=True." - ) from primary_error + "Failed to load model with AutoModelForCausalLM and fallback resolution.\n" + "Try one of:\n" + "1) Register with register_architecture(...) before loading.\n" + "2) Use model_type_override=''.\n" + "3) Use from_config_only=True with a loadable config " + "(usually trust_remote_code=True)." 
+ + error_details + ) from (fallback_error or primary_error) @classmethod def from_pretrained( @@ -333,13 +378,6 @@ def from_pretrained( try: from transformers import AutoConfig hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) - resolved_model_type = cls.resolve_model_type( - model_name, - config_model_type=getattr(hf_config, "model_type", None), - model_type_override=model_type_override, - ) - if resolved_model_type: - setattr(hf_config, "model_type", resolved_model_type) existing_quant = getattr(hf_config, "quantization_config", None) if existing_quant: @@ -365,7 +403,7 @@ def from_pretrained( pass # Ignore config loading errors, proceed with defaults # Apply quantization if requested - if quantize and smart_config.bits < 16 and not from_config_only: + if cls._should_apply_quantization(quantize, smart_config.bits, from_config_only): model_kwargs.update(cls._get_quantization_kwargs(smart_config)) # Device map for memory management diff --git a/tests/test_architecture_fallback.py b/tests/test_architecture_fallback.py index ceeb5ca..ade044b 100644 --- a/tests/test_architecture_fallback.py +++ b/tests/test_architecture_fallback.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from unittest.mock import Mock import transformers @@ -30,7 +31,9 @@ def _make_tokenizer(): def test_resolve_model_type_detects_common_patterns(): assert TurboModel.resolve_model_type("meta-llama/Llama-3.2-3B") == "llama" + # Newer Qwen names still fall back to the qwen2 base family. assert TurboModel.resolve_model_type("Qwen/Qwen3-8B") == "qwen2" + assert TurboModel.resolve_model_type("org/custom-arch-1b") is None def test_register_architecture_maps_new_model_to_base_family(monkeypatch): @@ -72,13 +75,14 @@ def from_pretrained(*args, **kwargs): def from_config(*args, **kwargs): return SimpleNamespace(config=SimpleNamespace(model_type="llama")) - class _RegisteredModel: - called = False + registered_call = Mock() - @classmethod - def from_pretrained(cls, *args, **kwargs): - cls.called = True - return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + def _registered_from_pretrained(cls, *args, **kwargs): + registered_call() + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + class _RegisteredModel: + from_pretrained = classmethod(_registered_from_pretrained) monkeypatch.setattr( turbo_model_module, @@ -95,7 +99,7 @@ def from_pretrained(cls, *args, **kwargs): verbose=False, ) - assert _RegisteredModel.called is True + assert registered_call.called is True assert loaded.model.config.model_type == "llama" From 2e596061868b710ac7c20d9d5e729bab0100a45e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 25 Apr 2026 13:28:36 +0000 Subject: [PATCH 4/6] Address PR review refinements for architecture fallback loading Agent-Logs-Url: https://github.com/codewithdark-git/QuantLLM/sessions/8867f3b4-18ae-4207-b2e8-51444418c7aa Co-authored-by: codewithdark-git <144595403+codewithdark-git@users.noreply.github.com> --- docs/guide/loading-models.md | 19 +- quantllm/core/turbo_model.py | 1251 ++++++++++++++++----------- tests/test_architecture_fallback.py | 111 ++- 3 files changed, 876 insertions(+), 505 deletions(-) diff --git a/docs/guide/loading-models.md b/docs/guide/loading-models.md index 54398e6..12de986 100644 --- a/docs/guide/loading-models.md +++ b/docs/guide/loading-models.md @@ -87,11 +87,14 @@ register_architecture("newmodel", base_model_type="llama") model = turbo( 
"new-model-org/NewModel-7B", model_type_override="llama", # optional explicit override - base_model_fallback=True, # retry with resolved fallback config + base_model_fallback=True, # enabled by default; can be disabled trust_remote_code=True, ) ``` +> âš ī¸ **Security note:** `trust_remote_code=True` executes model-provided code. +> Only enable it for trusted publishers, especially when loading unregistered or very new architectures. + You can also load from config only (no checkpoint weights) while waiting for upstream support: ```python @@ -110,6 +113,20 @@ model = turbo( - `turbo("org/model", base_model_fallback=True, trust_remote_code=True)` 3. Add/extend a focused test in `tests/test_architecture_fallback.py`. +#### Real-world style "released yesterday" example + +```python +from quantllm import turbo, register_architecture + +# Example: transformers doesn't recognize Qwen3 yet +register_architecture("qwen3", base_model_type="qwen2") + +model = turbo( + "Qwen/Qwen3-8B", + trust_remote_code=True, +) +``` + ### Memory Options ```python diff --git a/quantllm/core/turbo_model.py b/quantllm/core/turbo_model.py index d3b324f..c355fe0 100644 --- a/quantllm/core/turbo_model.py +++ b/quantllm/core/turbo_model.py @@ -4,29 +4,38 @@ Load, quantize, fine-tune, and export LLMs with one line each. """ +import copy import os import re import shutil import tempfile -import copy from functools import lru_cache -from typing import Optional, Dict, Any, Union, List, Type +from typing import Any, Dict, List, Optional, Type, Union + import torch import torch.nn as nn +from datasets.utils.logging import disable_progress_bar as disable_ds_progress_bar from transformers import ( AutoModelForCausalLM, AutoTokenizer, + GenerationConfig, PreTrainedModel, PreTrainedTokenizer, - GenerationConfig, ) +from transformers.utils.logging import disable_progress_bar as disable_hf_progress_bar -from .smart_config import SmartConfig +from ..utils import ( + QuantLLMProgress, + logger, + print_error, + print_header, + print_info, + print_success, + print_warning, +) from .hardware import HardwareProfiler -from ..utils import logger, print_header, print_success, print_error, print_info, print_warning, QuantLLMProgress -from transformers.utils.logging import disable_progress_bar as disable_hf_progress_bar -from datasets.utils.logging import disable_progress_bar as disable_ds_progress_bar from .memory import memory_optimized_tensor_order +from .smart_config import SmartConfig DEFAULT_CHUNKED_SHARD_SIZE = "2GB" DEFAULT_EXPORT_PUSH_CONFIG = { @@ -48,30 +57,30 @@ class TurboModel: """ High-performance LLM with the simplest possible API. - + Features: - One-line loading with automatic quantization - One-line fine-tuning with LoRA - One-line export to multiple formats - Automatic hardware optimization - + Example: >>> # Load any model in one line >>> model = TurboModel.from_pretrained("meta-llama/Llama-3-8B") - >>> + >>> >>> # Generate text >>> response = model.generate("Explain quantum computing") - >>> + >>> >>> # Fine-tune with your data >>> model.finetune("my_data.json", epochs=3) - >>> + >>> >>> # Export to GGUF >>> model.export("gguf", "my_model.gguf") """ - + _architecture_registry: Dict[str, str] = {} _model_class_registry: Dict[str, Type[PreTrainedModel]] = {} - + def __init__( self, model: PreTrainedModel, @@ -82,7 +91,7 @@ def __init__( ): """ Initialize TurboModel. Use from_pretrained() instead of direct init. 
- + Args: model: The loaded/quantized model tokenizer: Associated tokenizer @@ -107,7 +116,7 @@ def register_architecture( ) -> None: """ Register a new architecture alias and optional explicit model class. - + Args: architecture: Architecture or model type name to register base_model_type: Base model family to fall back to (e.g., "llama") @@ -116,13 +125,13 @@ def register_architecture( normalized = architecture.lower().strip() if not normalized: raise ValueError("architecture must be a non-empty string") - + if base_model_type: cls._architecture_registry[normalized] = base_model_type.lower().strip() - + if model_class is not None: cls._model_class_registry[normalized] = model_class - + @classmethod def resolve_model_type( cls, @@ -133,26 +142,26 @@ def resolve_model_type( ) -> Optional[str]: """ Resolve model type using override, registry, and default family patterns. - + If config_model_type is provided but unregistered, the original config value is returned unchanged. """ if model_type_override: return model_type_override.lower().strip() - + model_type = (config_model_type or "").lower().strip() if model_type: return cls._architecture_registry.get(model_type, model_type) - + name = model_name.lower() for pattern, fallback in cls._architecture_registry.items(): if cls._matches_model_name_pattern(name, pattern): return fallback - + for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items(): if cls._matches_model_name_pattern(name, pattern): return fallback - + return None @classmethod @@ -161,7 +170,7 @@ def _matches_model_name_pattern(cls, model_name: str, pattern: str) -> bool: return cls._compiled_model_name_pattern(pattern).search(model_name) is not None @staticmethod - @lru_cache(maxsize=256) + @lru_cache(maxsize=None) def _compiled_model_name_pattern(pattern: str): """Compile and cache token-boundary regex patterns for model-name matching.""" escaped = re.escape(pattern) @@ -190,19 +199,40 @@ def _load_model_with_fallback( from_config_only: bool, ) -> PreTrainedModel: """Load model with architecture fallback and optional config-only mode.""" + config_model_type = ( + (getattr(hf_config, "model_type", None) or "").lower().strip() + ) + is_registered_architecture = ( + config_model_type in cls._architecture_registry + if config_model_type + else False + ) resolved_model_type = cls.resolve_model_type( model_name, config_model_type=getattr(hf_config, "model_type", None), model_type_override=model_type_override, ) resolved_config = hf_config - + if hf_config is not None and resolved_model_type: current_model_type = getattr(hf_config, "model_type", None) if current_model_type != resolved_model_type: resolved_config = copy.deepcopy(hf_config) setattr(resolved_config, "model_type", resolved_model_type) - + + if ( + trust_remote_code + and config_model_type + and not is_registered_architecture + and config_model_type not in DEFAULT_ARCHITECTURE_FALLBACKS.values() + ): + logger.warning( + "trust_remote_code=True is enabled for unregistered architecture '%s' " + "(resolved fallback: '%s'). 
Only use this for models from trusted sources.", + config_model_type, + resolved_model_type, + ) + if from_config_only: if resolved_config is None: raise ValueError( @@ -214,22 +244,24 @@ def _load_model_with_fallback( trust_remote_code=trust_remote_code, torch_dtype=model_kwargs.get("torch_dtype"), ) - + try: return AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) except Exception as primary_error: if not base_model_fallback: raise fallback_error = None - + # Fallback priority: resolved config model_type -> explicitly registered model class. if resolved_config is not None: fallback_kwargs = dict(model_kwargs) fallback_kwargs["config"] = resolved_config try: - return AutoModelForCausalLM.from_pretrained(model_name, **fallback_kwargs) + return AutoModelForCausalLM.from_pretrained( + model_name, **fallback_kwargs + ) except Exception as fallback_config_error: fallback_error = fallback_config_error - + if resolved_model_type: registered_cls = cls._model_class_registry.get(resolved_model_type) if registered_cls is not None: @@ -237,22 +269,28 @@ def _load_model_with_fallback( if resolved_config is not None: class_kwargs["config"] = resolved_config try: - return registered_cls.from_pretrained(model_name, **class_kwargs) + return registered_cls.from_pretrained( + model_name, **class_kwargs + ) except Exception as fallback_registered_error: fallback_error = fallback_registered_error - - error_details = f" Last fallback error: {fallback_error}" if fallback_error else "" - + + error_details = ( + f" Last fallback error: {fallback_error}" if fallback_error else "" + ) + architecture_label = config_model_type or "" + resolved_label = resolved_model_type or "" + raise RuntimeError( "Failed to load model with AutoModelForCausalLM and fallback resolution.\n" + f"Architecture '{architecture_label}' resolved to base model type '{resolved_label}'.\n" "Try one of:\n" - "1) Register with register_architecture(...) before loading.\n" - "2) Use model_type_override=''.\n" + f"1) register_architecture('{architecture_label}', base_model_type='llama').\n" + "2) Use model_type_override='llama' (or your compatible base family).\n" "3) Use from_config_only=True with a loadable config " - "(usually trust_remote_code=True)." - + error_details + "(usually trust_remote_code=True)." + error_details ) from (fallback_error or primary_error) - + @classmethod def from_pretrained( cls, @@ -275,13 +313,13 @@ def from_pretrained( ) -> "TurboModel": """ Load a model with automatic optimization. - + This is the main entry point for QuantLLM. It automatically: - Detects your hardware capabilities - Chooses optimal quantization settings - Configures memory management - Enables speed optimizations - + Args: model_name: HuggingFace model name or local path bits: Override quantization bits (default: auto-detect) @@ -296,14 +334,14 @@ def from_pretrained( config_override: Dict to override any auto-detected settings config: Shared export/push config (format, quantization, push_format, etc.) 
verbose: Print loading progress - + Returns: TurboModel ready for inference or fine-tuning - + Example: >>> # Simplest usage - everything automatic >>> model = TurboModel.from_pretrained("meta-llama/Llama-3-8B") - >>> + >>> >>> # With specific bits >>> model = TurboModel.from_pretrained("mistral-7b", bits=4) >>> @@ -313,10 +351,10 @@ def from_pretrained( # Disable default progress bars disable_hf_progress_bar() disable_ds_progress_bar() - + if verbose: print_header(f"Loading {model_name}") - + # Auto-configure everything if verbose: logger.info("🚀 Detecting hardware and configuration...") @@ -328,9 +366,9 @@ def from_pretrained( device=device, dtype=dtype, ) - + from dataclasses import asdict - + # Apply user overrides if config_override: # Handle SmartConfig objects @@ -338,86 +376,98 @@ def from_pretrained( override_dict = asdict(config_override) else: override_dict = config_override - + for key, value in override_dict.items(): if hasattr(smart_config, key): setattr(smart_config, key, value) - + if verbose: smart_config.print_summary() - + # Load tokenizer if verbose: logger.info("📝 Loading tokenizer...") - + tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=trust_remote_code, ) - + # Ensure pad token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id - + # Load model with optimizations if verbose: - if smart_config.bits != smart_config.effective_loading_bits and smart_config.bits < 16: - logger.info(f"đŸ“Ļ Loading model ({smart_config.effective_loading_bits}-bit, for {smart_config.bits}-bit GGUF export)...") + if ( + smart_config.bits != smart_config.effective_loading_bits + and smart_config.bits < 16 + ): + logger.info( + f"đŸ“Ļ Loading model ({smart_config.effective_loading_bits}-bit, for {smart_config.bits}-bit GGUF export)..." + ) else: logger.info(f"đŸ“Ļ Loading model ({smart_config.bits}-bit)...") - + model_kwargs = { "trust_remote_code": trust_remote_code, "torch_dtype": smart_config.dtype, } - + hf_config = None - + # Check if model is already quantized to prevent conflicts try: from transformers import AutoConfig - hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) - + + hf_config = AutoConfig.from_pretrained( + model_name, trust_remote_code=trust_remote_code + ) + existing_quant = getattr(hf_config, "quantization_config", None) if existing_quant: allow_requantize = False - + # Allow 8-bit -> 4-bit re-quantization (if B&B) is_bnb = "BitsAndBytesConfig" in existing_quant.__class__.__name__ is_8bit = getattr(existing_quant, "load_in_8bit", False) - + if is_bnb and is_8bit and smart_config.bits == 4: allow_requantize = True if is_bnb and is_8bit and smart_config.bits == 4: allow_requantize = True if verbose: logger.info(" â„šī¸ Re-quantizing 8-bit model to 4-bit") - + if not allow_requantize: if verbose: - logger.warning(f"âš ī¸ Model is already quantized ({existing_quant.__class__.__name__}). Disabling dynamic quantization.") + logger.warning( + f"âš ī¸ Model is already quantized ({existing_quant.__class__.__name__}). Disabling dynamic quantization." 
+ ) quantize = False - + except Exception: - pass # Ignore config loading errors, proceed with defaults + pass # Ignore config loading errors, proceed with defaults # Apply quantization if requested - if cls._should_apply_quantization(quantize, smart_config.bits, from_config_only): + if cls._should_apply_quantization( + quantize, smart_config.bits, from_config_only + ): model_kwargs.update(cls._get_quantization_kwargs(smart_config)) - + # Device map for memory management if smart_config.cpu_offload: model_kwargs["device_map"] = "auto" model_kwargs["offload_folder"] = "offload" elif torch.cuda.is_available(): model_kwargs["device_map"] = {"": smart_config.device} - + # Load the model with progress spinner with QuantLLMProgress() as p: if verbose: task = p.add_task("Downloading & Loading model...", total=None) - + model = cls._load_model_with_fallback( model_name, model_kwargs, @@ -427,21 +477,21 @@ def from_pretrained( base_model_fallback=base_model_fallback, from_config_only=from_config_only, ) - + if verbose: p.update(task, completed=100) - + # Apply additional optimizations if smart_config.gradient_checkpointing: - if hasattr(model, 'gradient_checkpointing_enable'): + if hasattr(model, "gradient_checkpointing_enable"): model.gradient_checkpointing_enable() if verbose: logger.info(" ✓ Gradient checkpointing enabled") - + # Enable Flash Attention if available if smart_config.use_flash_attention: cls._enable_flash_attention(model, verbose) - + # Compile model if beneficial if smart_config.compile_model: try: @@ -451,16 +501,16 @@ def from_pretrained( except Exception as e: if verbose: print_warning(f"torch.compile failed: {e}") - + if verbose: print_success("Model loaded successfully!") logger.info("") - + instance = cls(model, tokenizer, smart_config, export_push_config=config) instance._is_quantized = quantize and smart_config.bits < 16 - + return instance - + @classmethod def from_gguf( cls, @@ -469,45 +519,46 @@ def from_gguf( *, device: Optional[str] = None, verbose: bool = True, - **kwargs + **kwargs, ) -> "TurboModel": """ Load a GGUF model directly from HuggingFace Hub or local path. - + This uses transformers' native GGUF support (requires transformers>=4.36.0). - + Args: model_id: HuggingFace repo ID (e.g., "TheBloke/Llama-2-7B-GGUF") or local directory - filename: GGUF filename (e.g., "llama-2-7b.Q4_K_M.gguf"). + filename: GGUF filename (e.g., "llama-2-7b.Q4_K_M.gguf"). If None, tries to auto-find. Use list_gguf_files() to see available options. device: Override device (default: auto-detect best GPU) verbose: Print progress **kwargs: Additional args for AutoModelForCausalLM.from_pretrained - + Returns: TurboModel with loaded GGUF model - + Example: >>> # List available GGUF files in a repo >>> files = TurboModel.list_gguf_files("TheBloke/Llama-2-7B-GGUF") >>> print(files) - >>> + >>> >>> # Load specific quantization >>> model = TurboModel.from_gguf( - ... "TheBloke/Llama-2-7B-GGUF", + ... "TheBloke/Llama-2-7B-GGUF", ... filename="llama-2-7b.Q4_K_M.gguf" ... 
) - >>> + >>> >>> # Generate text >>> model.generate("Hello!") """ if verbose: print_header(f"Loading GGUF: {model_id}") - + # Check for GGUF package try: import gguf - gguf_version = getattr(gguf, '__version__', 'unknown') + + gguf_version = getattr(gguf, "__version__", "unknown") if verbose: print_info(f"Using gguf version: {gguf_version}") except ImportError: @@ -516,7 +567,7 @@ def from_gguf( "Loading GGUF models requires the 'gguf' package.\n" "Please run: pip install gguf>=0.10.0" ) - + # If no filename specified, try to find one if filename is None: if verbose: @@ -525,31 +576,33 @@ def from_gguf( available_files = cls.list_gguf_files(model_id) if available_files: # Prefer Q4_K_M if available, otherwise take first - q4_files = [f for f in available_files if 'q4_k_m' in f.lower()] + q4_files = [f for f in available_files if "q4_k_m" in f.lower()] filename = q4_files[0] if q4_files else available_files[0] if verbose: - print_info(f"Found {len(available_files)} GGUF files, using: {filename}") + print_info( + f"Found {len(available_files)} GGUF files, using: {filename}" + ) except Exception as e: if verbose: print_warning(f"Could not list GGUF files: {e}") - + if verbose and filename: print_info(f"Loading: {filename}") - + smart_config = SmartConfig.detect(model_id, device=device) smart_config.quant_type = "GGUF" - + with QuantLLMProgress() as progress: if verbose: task = progress.add_task("Loading GGUF model...", total=None) - + try: model = AutoModelForCausalLM.from_pretrained( model_id, gguf_file=filename, torch_dtype=smart_config.dtype, trust_remote_code=True, - **kwargs + **kwargs, ) except ImportError as e: if "gguf" in str(e).lower(): @@ -566,7 +619,7 @@ def from_gguf( f"Available files: {available}" ) from e raise - + # Load tokenizer try: tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename) @@ -576,32 +629,34 @@ def from_gguf( except Exception: # Some GGUF repos might not have tokenizer, try base model if verbose: - print_warning("Could not load tokenizer from GGUF repo, using default") + print_warning( + "Could not load tokenizer from GGUF repo, using default" + ) tokenizer = None - + if verbose: print_success("GGUF Model loaded successfully!") - + # Print model info - if hasattr(model, 'num_parameters'): + if hasattr(model, "num_parameters"): params = model.num_parameters() / 1e9 print_info(f"Parameters: {params:.2f}B") - + instance = cls(model, tokenizer, smart_config, verbose=verbose) instance._is_quantized = True return instance - + @staticmethod def list_gguf_files(model_id: str) -> List[str]: """ List available GGUF files in a HuggingFace repository. - + Args: model_id: HuggingFace repo ID (e.g., "TheBloke/Llama-2-7B-GGUF") - + Returns: List of GGUF filenames available in the repository - + Example: >>> files = TurboModel.list_gguf_files("TheBloke/Llama-2-7B-GGUF") >>> print(files) @@ -609,50 +664,60 @@ def list_gguf_files(model_id: str) -> List[str]: """ try: from huggingface_hub import list_repo_files - + all_files = list_repo_files(model_id) - gguf_files = [f for f in all_files if f.endswith('.gguf')] - + gguf_files = [f for f in all_files if f.endswith(".gguf")] + # Sort by quantization quality (Q4_K_M before Q2_K, etc.) 
def quant_sort_key(name): name_lower = name.lower() # Higher number = better quality, listed first - if 'f32' in name_lower: return 0 - if 'f16' in name_lower: return 1 - if 'q8' in name_lower: return 2 - if 'q6' in name_lower: return 3 - if 'q5_k_m' in name_lower: return 4 - if 'q5_k_s' in name_lower: return 5 - if 'q4_k_m' in name_lower: return 6 - if 'q4_k_s' in name_lower: return 7 - if 'q3_k' in name_lower: return 8 - if 'q2_k' in name_lower: return 9 + if "f32" in name_lower: + return 0 + if "f16" in name_lower: + return 1 + if "q8" in name_lower: + return 2 + if "q6" in name_lower: + return 3 + if "q5_k_m" in name_lower: + return 4 + if "q5_k_s" in name_lower: + return 5 + if "q4_k_m" in name_lower: + return 6 + if "q4_k_s" in name_lower: + return 7 + if "q3_k" in name_lower: + return 8 + if "q2_k" in name_lower: + return 9 return 10 - + return sorted(gguf_files, key=quant_sort_key) - + except Exception as e: # If it's a local path, list directory if os.path.isdir(model_id): - return [f for f in os.listdir(model_id) if f.endswith('.gguf')] + return [f for f in os.listdir(model_id) if f.endswith(".gguf")] raise ValueError(f"Could not list GGUF files from {model_id}: {e}") @staticmethod def _get_quantization_kwargs(config: SmartConfig) -> Dict[str, Any]: """ Get kwargs for quantized model loading. - + Note: BitsAndBytes only supports 4-bit and 8-bit quantization for loading. Other bit widths (2, 3, 5, 6) are only available during GGUF export. - + For loading: - bits <= 4: Uses 4-bit NF4 quantization - - bits 5-7: Uses 8-bit quantization + - bits 5-7: Uses 8-bit quantization - bits >= 8: Uses 8-bit quantization """ try: from transformers import BitsAndBytesConfig - + # BitsAndBytes only supports 4-bit and 8-bit # Map requested bits to available options if config.bits <= 4: @@ -665,8 +730,12 @@ def _get_quantization_kwargs(config: SmartConfig) -> Dict[str, Any]: bnb_4bit_use_double_quant=True, ) if config.bits < 4: - logger.info(f" â„šī¸ BitsAndBytes supports 4/8-bit only. Using 4-bit for requested {config.bits}-bit.") - logger.info(f" Tip: Export to GGUF for Q{config.bits}_K quantization!") + logger.info( + f" â„šī¸ BitsAndBytes supports 4/8-bit only. Using 4-bit for requested {config.bits}-bit." + ) + logger.info( + f" Tip: Export to GGUF for Q{config.bits}_K quantization!" + ) else: # 5, 6, 7, 8-bit requests -> 8-bit effective_bits = 8 @@ -674,11 +743,15 @@ def _get_quantization_kwargs(config: SmartConfig) -> Dict[str, Any]: load_in_8bit=True, ) if config.bits != 8: - logger.info(f" â„šī¸ BitsAndBytes supports 4/8-bit only. Using 8-bit for requested {config.bits}-bit.") - logger.info(f" Tip: Export to GGUF for Q{config.bits}_K quantization!") - + logger.info( + f" â„šī¸ BitsAndBytes supports 4/8-bit only. Using 8-bit for requested {config.bits}-bit." + ) + logger.info( + f" Tip: Export to GGUF for Q{config.bits}_K quantization!" 
+ ) + return {"quantization_config": quantization_config} - + except ImportError: logger.warning("⚠ bitsandbytes not installed, loading without quantization") return {} @@ -706,20 +779,20 @@ def _build_export_push_config(config: Optional[Dict[str, Any]]) -> Dict[str, Any resolved["push_quantization"] = resolved["quantization"] return resolved - + @staticmethod def _enable_flash_attention(model: PreTrainedModel, verbose: bool = True) -> None: """Enable Flash Attention if available.""" try: # Try to use native Flash Attention 2 - if hasattr(model, 'config'): + if hasattr(model, "config"): model.config._attn_implementation = "flash_attention_2" if verbose: logger.info(" ✓ Flash Attention 2 enabled") except Exception: if verbose: logger.warning(" ⚠ Flash Attention not available") - + def generate( self, prompt: str, @@ -736,7 +809,7 @@ def generate( ) -> str: """ Generate text from a prompt. - + Args: prompt: Input text prompt max_new_tokens: Maximum tokens to generate @@ -748,18 +821,18 @@ def generate( repetition_penalty: Penalty for repeating tokens (>1.0 = less repetition) stop_strings: List of strings that stop generation **kwargs: Additional generation parameters - + Returns: Generated text response - + Example: >>> response = model.generate("Explain quantum computing") - >>> + >>> >>> # With streaming >>> response = model.generate("Tell me a story", stream=True) """ import sys - + # Tokenize input inputs = self.tokenizer( prompt, @@ -767,20 +840,27 @@ def generate( truncation=True, max_length=self.config.max_seq_length - max_new_tokens, ) - + # Move to device inputs = {k: v.to(self.model.device) for k, v in inputs.items()} - + # Default stop strings stop_strings = stop_strings or [] - + # Streaming generation if stream: return self._generate_streaming( - inputs, max_new_tokens, temperature, top_p, top_k, - do_sample, repetition_penalty, stop_strings, **kwargs + inputs, + max_new_tokens, + temperature, + top_p, + top_k, + do_sample, + repetition_penalty, + stop_strings, + **kwargs, ) - + # Non-streaming generation with torch.inference_mode(): outputs = self.model.generate( @@ -795,19 +875,19 @@ def generate( repetition_penalty=repetition_penalty, **kwargs, ) - + # Decode, removing the prompt - generated = outputs[0][inputs["input_ids"].shape[1]:] + generated = outputs[0][inputs["input_ids"].shape[1] :] response = self.tokenizer.decode(generated, skip_special_tokens=True) - + # Check for stop strings and truncate for stop in stop_strings: if stop in response: response = response.split(stop)[0] break - + return response.strip() - + def _generate_streaming( self, inputs: Dict, @@ -822,17 +902,18 @@ def _generate_streaming( ) -> str: """Generate with streaming output.""" import sys - + try: - from transformers import TextIteratorStreamer from threading import Thread - + + from transformers import TextIteratorStreamer + streamer = TextIteratorStreamer( - self.tokenizer, + self.tokenizer, skip_prompt=True, skip_special_tokens=True, ) - + generation_kwargs = { **inputs, "max_new_tokens": max_new_tokens, @@ -846,18 +927,18 @@ def _generate_streaming( "streamer": streamer, **kwargs, } - + # Run generation in background thread thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() - + # Stream output generated_text = [] for new_text in streamer: sys.stdout.write(new_text) sys.stdout.flush() generated_text.append(new_text) - + # Check stop strings full_text = "".join(generated_text) should_stop = False @@ -867,17 +948,17 @@ def _generate_streaming( break if 
should_stop: break - + thread.join() print() # New line after streaming - + response = "".join(generated_text) for stop in stop_strings: if stop in response: response = response.split(stop)[0] - + return response.strip() - + except ImportError: # Fallback to non-streaming print("(Streaming not available, using batch generation)") @@ -887,7 +968,7 @@ def _generate_streaming( temperature=temperature, stream=False, ) - + def chat( self, messages: List[Dict[str, str]], @@ -895,14 +976,14 @@ def chat( ) -> str: """ Generate response for chat-format messages. - + Args: messages: List of {"role": "user/assistant/system", "content": "..."} **kwargs: Additional generation parameters - + Returns: Assistant's response - + Example: >>> response = model.chat([ ... {"role": "system", "content": "You are a helpful assistant."}, @@ -911,9 +992,12 @@ def chat( """ # Try to apply chat template prompt = None - + # First, try native chat template - if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template is not None: + if ( + hasattr(self.tokenizer, "chat_template") + and self.tokenizer.chat_template is not None + ): try: prompt = self.tokenizer.apply_chat_template( messages, @@ -922,25 +1006,25 @@ def chat( ) except Exception: pass - + # If no template, use a sensible default if prompt is None: # Default chat format (works for most models) parts = [] for m in messages: - role = m.get('role', 'user') - content = m.get('content', '') - if role == 'system': + role = m.get("role", "user") + content = m.get("content", "") + if role == "system": parts.append(f"System: {content}\n") - elif role == 'user': + elif role == "user": parts.append(f"User: {content}\n") - elif role == 'assistant': + elif role == "assistant": parts.append(f"Assistant: {content}\n") parts.append("Assistant:") prompt = "".join(parts) - + return self.generate(prompt, **kwargs) - + def finetune( self, data: Union[str, List[Dict[str, str]], Any], @@ -956,7 +1040,7 @@ def finetune( ) -> Dict[str, Any]: """ Fine-tune the model with LoRA. - + Args: data: Training data - file path, list of dicts, or HF dataset epochs: Number of training epochs (default: auto) @@ -967,7 +1051,7 @@ def finetune( output_dir: Where to save (default: ./output/{model_name}) hub_manager: QuantLLMHubManager instance for auto-tracking **kwargs: Additional training arguments - + Returns: Training results dictionary """ @@ -977,33 +1061,40 @@ def finetune( os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" if "WANDB_DISABLED" not in os.environ: os.environ["WANDB_DISABLED"] = "true" - + # Suppress noise import warnings + warnings.filterwarnings("ignore", module="peft") warnings.filterwarnings("ignore", category=FutureWarning) - + try: from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training - from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling + from transformers import ( + DataCollatorForLanguageModeling, + Trainer, + TrainingArguments, + ) except ImportError: - raise ImportError("peft is required for fine-tuning. Install with: pip install peft") - + raise ImportError( + "peft is required for fine-tuning. Install with: pip install peft" + ) + # 2. Prepare model for training if self._is_quantized: self.model = prepare_model_for_kbit_training(self.model) - + # Enable gradient checkpointing if configured if self.config.gradient_checkpointing: self.model.gradient_checkpointing_enable() - + # 3. 
Auto-configure LoRA r = lora_r or (16 if self.model.num_parameters() < 10e9 else 64) alpha = lora_alpha or (r * 2) - + # Determine target modules based on model type target_modules = self._get_lora_target_modules() - + lora_config = LoraConfig( r=r, lora_alpha=alpha, @@ -1012,37 +1103,43 @@ def finetune( bias="none", task_type="CAUSAL_LM", ) - + # Apply LoRA if not already applied if not self._lora_applied: self.model = get_peft_model(self.model, lora_config) self._lora_applied = True - + trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad) total = sum(p.numel() for p in self.model.parameters()) - print_info(f"LoRA applied: {trainable:,} trainable params ({100*trainable/total:.2f}%)") - + print_info( + f"LoRA applied: {trainable:,} trainable params ({100*trainable/total:.2f}%)" + ) + # 4. Load and prepare data train_dataset = self._prepare_dataset(data) - + # 5. Auto-configure training settings epochs = epochs or 3 batch_size = batch_size or self.config.batch_size learning_rate = learning_rate or 2e-4 - output_dir = output_dir or f"./output/{self.model.config._name_or_path.split('/')[-1]}" - + output_dir = ( + output_dir or f"./output/{self.model.config._name_or_path.split('/')[-1]}" + ) + # Auto-track parameters if hub_manager provided if hub_manager: - hub_manager.track_hyperparameters({ - "epochs": epochs, - "learning_rate": learning_rate, - "batch_size": batch_size, - "lora_r": r, - "lora_alpha": alpha, - "base_model": getattr(self.config, "model_name", "unknown"), - "output_dir": output_dir - }) - + hub_manager.track_hyperparameters( + { + "epochs": epochs, + "learning_rate": learning_rate, + "batch_size": batch_size, + "lora_r": r, + "lora_alpha": alpha, + "base_model": getattr(self.config, "model_name", "unknown"), + "output_dir": output_dir, + } + ) + # Training loop try: training_args = TrainingArguments( @@ -1062,99 +1159,138 @@ def finetune( torch_compile=self.config.compile_model, **kwargs, ) - + trainer = Trainer( model=self.model, args=training_args, train_dataset=train_dataset, - tokenizer=self.tokenizer, # Use new argument name - data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + tokenizer=self.tokenizer, # Use new argument name + data_collator=DataCollatorForLanguageModeling( + self.tokenizer, mlm=False + ), ) - + result = trainer.train() self._is_finetuned = True - + print_success(f"Training complete! Model saved to {output_dir}") - + return { "train_loss": result.training_loss, "epochs": epochs, "output_dir": output_dir, "learning_rate": learning_rate, "batch_size": batch_size, - "lora_r": r + "lora_r": r, } - + except Exception as e: print_error(f"Training failed: {e}") raise # Hint about OOM if "out of memory" in str(e).lower(): - print_info("Tip: Try reducing batch_size or enabling gradient_checkpointing in config.") + print_info( + "Tip: Try reducing batch_size or enabling gradient_checkpointing in config." 
+ ) raise - + def _get_lora_target_modules(self) -> List[str]: """Get appropriate LoRA target modules for the model.""" - model_type = getattr(self.model.config, 'model_type', '').lower() - + model_type = getattr(self.model.config, "model_type", "").lower() + # Common patterns by model type target_patterns = { - "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], - "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], - "qwen2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "llama": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + "mistral": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + "qwen2": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], "phi": ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"], - "gemma": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "gemma": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], } - + return target_patterns.get(model_type, ["q_proj", "v_proj"]) - + def _prepare_dataset(self, data: Union[str, List[Dict], Any]) -> Any: """Prepare dataset for training.""" from datasets import Dataset, load_dataset - + if isinstance(data, str): # Load from file - if data.endswith('.json') or data.endswith('.jsonl'): - dataset = load_dataset('json', data_files=data)['train'] + if data.endswith(".json") or data.endswith(".jsonl"): + dataset = load_dataset("json", data_files=data)["train"] else: - dataset = load_dataset(data)['train'] + dataset = load_dataset(data)["train"] elif isinstance(data, list): dataset = Dataset.from_list(data) else: dataset = data # Assume it's already a Dataset - + # Tokenize def tokenize_function(examples): # Handle different data formats - if 'text' in examples: - texts = examples['text'] - elif 'instruction' in examples and 'output' in examples: + if "text" in examples: + texts = examples["text"] + elif "instruction" in examples and "output" in examples: texts = [ f"### Instruction:\n{inst}\n\n### Response:\n{out}" - for inst, out in zip(examples['instruction'], examples['output']) + for inst, out in zip(examples["instruction"], examples["output"]) ] else: # Try to concatenate all string fields keys = [k for k in examples.keys() if isinstance(examples[k][0], str)] - texts = [' '.join([examples[k][i] for k in keys]) for i in range(len(examples[keys[0]]))] - + texts = [ + " ".join([examples[k][i] for k in keys]) + for i in range(len(examples[keys[0]])) + ] + return self.tokenizer( texts, truncation=True, max_length=self.config.max_seq_length, - padding=False, # Use DataCollator for dynamic padding + padding=False, # Use DataCollator for dynamic padding ) - + tokenized = dataset.map( tokenize_function, batched=True, remove_columns=dataset.column_names, - load_from_cache_file=False, # Avoid hash warnings + load_from_cache_file=False, # Avoid hash warnings desc="Tokenizing dataset", ) - + return tokenized - + def export( self, format: Optional[str] = None, @@ -1165,13 +1301,13 @@ def export( ) -> str: """ Export model to various formats. - + Supported formats: - "gguf": For llama.cpp, Ollama, LM Studio (Q4_K_M, Q5_K_M, etc.) - "safetensors": For HuggingFace ecosystem - "onnx": For ONNX Runtime, TensorRT - "mlx": For Apple Silicon Macs - + Args: format: Target format (gguf, safetensors, onnx, mlx). 
Uses shared config when omitted. output_path: Output file/directory path @@ -1180,10 +1316,10 @@ def export( - ONNX: dynamic, int8 - MLX: 4bit, 8bit **kwargs: Format-specific options - + Returns: Path to exported model - + Example: >>> model.export("gguf") # Uses auto name >>> model.export("gguf", "my_model.gguf", quantization="Q4_K_M") @@ -1193,24 +1329,26 @@ def export( format = ( format if format is not None - else self.export_push_config.get("format", DEFAULT_EXPORT_PUSH_CONFIG["format"]) + else self.export_push_config.get( + "format", DEFAULT_EXPORT_PUSH_CONFIG["format"] + ) ).lower() effective_quantization = quantization if effective_quantization is None and format == "gguf": effective_quantization = self.export_push_config.get( "quantization", DEFAULT_EXPORT_PUSH_CONFIG["quantization"] ) - + # Merge LoRA if applied if self._lora_applied: if self.verbose: print_info("Merging LoRA weights before export...") self.model = self.model.merge_and_unload() self._lora_applied = False - + # Auto-generate output path if output_path is None: - model_name = self.model.config._name_or_path.split('/')[-1] + model_name = self.model.config._name_or_path.split("/")[-1] if format == "gguf": quant = effective_quantization output_path = f"{model_name}.{quant.upper()}.gguf" @@ -1222,7 +1360,7 @@ def export( output_path = f"./{model_name}-mlx/" else: output_path = f"./{model_name}-{format}/" - + exporters = { "gguf": self._export_gguf, "safetensors": self._export_safetensors, @@ -1230,12 +1368,16 @@ def export( "mlx": self._export_mlx, } if format not in exporters: - raise ValueError(f"Unknown format: {format}. Supported: {list(exporters.keys())}") - + raise ValueError( + f"Unknown format: {format}. Supported: {list(exporters.keys())}" + ) + print_header(f"Exporting to {format.upper()}") - result = exporters[format](output_path, quantization=effective_quantization, **kwargs) + result = exporters[format]( + output_path, quantization=effective_quantization, **kwargs + ) print_success(f"Exported to: {result}") - + return result def push_to_hub( @@ -1246,11 +1388,11 @@ def push_to_hub( quantization: Optional[str] = None, commit_message: str = "Upload model via QuantLLM", license: str = "apache-2.0", - **kwargs + **kwargs, ): """ Push model to HuggingFace Hub with proper model card. - + Args: repo_id: Repository ID (e.g. 
"username/model") token: HF Token @@ -1259,39 +1401,41 @@ def push_to_hub( commit_message: Commit message license: License type (default: apache-2.0) **kwargs: Arguments for export - + Supported formats: - safetensors: Standard HuggingFace format - gguf: For llama.cpp, Ollama, LM Studio - onnx: For ONNX Runtime, TensorRT - mlx: For Apple Silicon (requires macOS) - + The model card will be automatically generated with: - Proper YAML frontmatter for HuggingFace - Format-specific usage examples - "Use this model" button compatibility """ from ..hub import QuantLLMHubManager - + format_lower = ( format if format is not None - else self.export_push_config.get("push_format", DEFAULT_EXPORT_PUSH_CONFIG["push_format"]) + else self.export_push_config.get( + "push_format", DEFAULT_EXPORT_PUSH_CONFIG["push_format"] + ) ).lower() push_quantization = quantization or self.export_push_config.get( "push_quantization", DEFAULT_EXPORT_PUSH_CONFIG["push_quantization"] ) - + # Get the original base model name (full path for HuggingFace link) base_model_full = self.model.config._name_or_path - model_name = base_model_full.split('/')[-1] - + model_name = base_model_full.split("/")[-1] + print_header(f"Pushing to {repo_id}") print_info(f"Format: {format_lower.upper()}") print_info(f"Base model: {base_model_full}") - + manager = QuantLLMHubManager(repo_id=repo_id, hf_token=token) - + if format_lower == "gguf": # Export GGUF directly to staging quant_label = push_quantization or self.export_push_config.get( @@ -1299,65 +1443,75 @@ def push_to_hub( ) filename = f"{model_name}.{quant_label.upper()}.gguf" save_path = os.path.join(manager.staging_dir, filename) - - self.export(format="gguf", output_path=save_path, quantization=quant_label, **kwargs) - - manager.track_hyperparameters({ - "format": "gguf", - "quantization": quant_label.upper(), - "base_model": base_model_full, - "license": license, - }) + + self.export( + format="gguf", output_path=save_path, quantization=quant_label, **kwargs + ) + + manager.track_hyperparameters( + { + "format": "gguf", + "quantization": quant_label.upper(), + "base_model": base_model_full, + "license": license, + } + ) manager._generate_model_card(format="gguf") - + elif format_lower == "onnx": # Export to ONNX format print_info("Exporting to ONNX format...") save_path = manager.staging_dir - + self._export_onnx(save_path, quantization=push_quantization, **kwargs) - - manager.track_hyperparameters({ - "format": "onnx", - "quantization": push_quantization, - "base_model": base_model_full, - "license": license, - }) + + manager.track_hyperparameters( + { + "format": "onnx", + "quantization": push_quantization, + "base_model": base_model_full, + "license": license, + } + ) manager._generate_model_card(format="onnx") - + elif format_lower == "mlx": # Export to MLX format print_info("Exporting to MLX format...") save_path = manager.staging_dir - + self._export_mlx(save_path, quantization=push_quantization, **kwargs) - - manager.track_hyperparameters({ - "format": "mlx", - "quantization": push_quantization, - "base_model": base_model_full, - "license": license, - }) + + manager.track_hyperparameters( + { + "format": "mlx", + "quantization": push_quantization, + "base_model": base_model_full, + "license": license, + } + ) manager._generate_model_card(format="mlx") - + else: # SafeTensors or PyTorch format - manager.track_hyperparameters({ - "format": format_lower, - "base_model": base_model_full, - "license": license, - }) + manager.track_hyperparameters( + { + "format": format_lower, + 
"base_model": base_model_full, + "license": license, + } + ) manager.save_final_model(self, format=format_lower) manager._generate_model_card(format=format_lower) - + manager.push(commit_message=commit_message) - + # Alias for convenience push = push_to_hub - + def _export_gguf( - self, - output_path: str, + self, + output_path: str, quantization: Optional[str] = None, fast_mode: bool = False, chunked_conversion: bool = False, @@ -1365,19 +1519,19 @@ def _export_gguf( smart_tensor_ordering: bool = False, disk_offloading: bool = False, disk_offload_dir: Optional[str] = None, - **kwargs + **kwargs, ) -> str: """ Export to GGUF format using optimized llama.cpp converter. - + Automatically installs and configures llama.cpp tools. Handles BitsAndBytes quantized models by dequantizing first. - + Flow: 1. Save model to temp directory (dequantize if needed) 2. Convert to F16 GGUF using convert_hf_to_gguf.py 3. Quantize to target format (Q4_K_M, Q5_K_M, etc.) using llama-quantize - + Args: output_path: Output file path for GGUF quantization: Quantization type (Q4_K_M, Q5_K_M, Q8_0, etc.) @@ -1388,71 +1542,87 @@ def _export_gguf( disk_offloading: Use a dedicated temp/offload directory for intermediate artifacts disk_offload_dir: Directory used when disk_offloading=True """ - from ..quant import convert_to_gguf, quantize_gguf, ensure_llama_cpp_installed, GGUF_QUANT_TYPES - from ..utils import QuantLLMProgress, format_time, format_size import time - + + from ..quant import ( + GGUF_QUANT_TYPES, + convert_to_gguf, + ensure_llama_cpp_installed, + quantize_gguf, + ) + from ..utils import QuantLLMProgress, format_size, format_time + start_time = time.time() - + effective_shard_size = max_shard_size or ( DEFAULT_CHUNKED_SHARD_SIZE if chunked_conversion else None ) - + quant_type = quantization or self.config.quant_type or "q4_k_m" quant_type_upper = quant_type.upper() quant_type_lower = quant_type.lower() - + # Check if this is a passthrough format (f16, bf16, f32 - no quantization needed) - passthrough_types = {'f16', 'f32', 'bf16', 'float16', 'float32', 'bfloat16'} + passthrough_types = {"f16", "f32", "bf16", "float16", "float32", "bfloat16"} needs_quantization = quant_type_lower not in passthrough_types - + if self.verbose: print_info(f"Target quantization: {quant_type_upper}") if fast_mode: print_info("Fast mode enabled") if chunked_conversion: - print_info(f"Chunked conversion enabled (max_shard_size={effective_shard_size})") + print_info( + f"Chunked conversion enabled (max_shard_size={effective_shard_size})" + ) if smart_tensor_ordering: print_info("Smart tensor ordering enabled") - print_warning("Smart tensor ordering may temporarily materialize a full state dict in memory.") + print_warning( + "Smart tensor ordering may temporarily materialize a full state dict in memory." + ) if disk_offloading: - print_info(f"Disk offloading enabled ({disk_offload_dir or 'system temp'})") - + print_info( + f"Disk offloading enabled ({disk_offload_dir or 'system temp'})" + ) + # Ensure llama.cpp if self.verbose: print_info("Checking llama.cpp installation...") ensure_llama_cpp_installed() - + # Check if model is BitsAndBytes quantized and needs dequantization model_to_save = self.model is_bnb_quantized = self._is_bnb_quantized() - + if is_bnb_quantized: if self.verbose: - print_warning("Model is BitsAndBytes quantized. Dequantizing for GGUF export...") - print_info("This may use significant memory. For large models, consider loading with quantize=False.") - + print_warning( + "Model is BitsAndBytes quantized. 
Dequantizing for GGUF export..." + ) + print_info( + "This may use significant memory. For large models, consider loading with quantize=False." + ) + model_to_save = self._dequantize_model() if self.verbose: print_success("Model dequantized successfully!") - + # Determine dtype for initial conversion (always F16 for best quality) model_dtype = "f16" - + # Get model name for file naming - model_name = self.model.config._name_or_path.split('/')[-1] - + model_name = self.model.config._name_or_path.split("/")[-1] + temp_parent = disk_offload_dir if disk_offloading else None if temp_parent: os.makedirs(temp_parent, exist_ok=True) - + # Create temp dir for conversion with tempfile.TemporaryDirectory(dir=temp_parent) as temp_dir: # Step 1: Save model to temp directory if self.verbose: print_header("Step 1/3: Saving Model", icon="💾") print_info(f"Staging model to {temp_dir}...") - + with QuantLLMProgress() as progress: task = progress.add_task("Saving model weights...", total=None) save_kwargs = { @@ -1460,95 +1630,101 @@ def _export_gguf( } if effective_shard_size: save_kwargs["max_shard_size"] = effective_shard_size - + if smart_tensor_ordering: - save_kwargs["state_dict"] = memory_optimized_tensor_order(model_to_save.state_dict()) - + save_kwargs["state_dict"] = memory_optimized_tensor_order( + model_to_save.state_dict() + ) + try: model_to_save.save_pretrained(temp_dir, **save_kwargs) except Exception as e: if self.verbose: - print_warning(f"SafeTensors save failed ({e}), using PyTorch format...") + print_warning( + f"SafeTensors save failed ({e}), using PyTorch format..." + ) save_kwargs["safe_serialization"] = False model_to_save.save_pretrained(temp_dir, **save_kwargs) - + self.tokenizer.save_pretrained(temp_dir) progress.update(task, completed=100) - + if self.verbose: print_success("Model saved to staging area!") - + # Step 2: Convert to F16 GGUF if self.verbose: print_header("Step 2/3: Converting to GGUF", icon="🔄") - + # F16 intermediate file (or final if no quantization needed) if needs_quantization: f16_gguf_file = os.path.join(temp_dir, f"{model_name}.F16.gguf") else: f16_gguf_file = f"{model_name}.{quant_type_upper}.gguf" - + output_files, _ = convert_to_gguf( model_name=model_name, input_folder=temp_dir, model_dtype=model_dtype, quantization_type="f16" if needs_quantization else quant_type_lower, - print_output=self.verbose + print_output=self.verbose, ) - + if not output_files: raise RuntimeError("GGUF conversion failed to produce output file.") - + f16_file = output_files[0] - + # If conversion produced a different name, use that if os.path.exists(f16_file): f16_gguf_file = f16_file - + if self.verbose: print_success(f"F16 GGUF created: {f16_gguf_file}") - + # Step 3: Apply quantization if needed if needs_quantization: if self.verbose: - print_header(f"Step 3/3: Quantizing to {quant_type_upper}", icon="⚡") + print_header( + f"Step 3/3: Quantizing to {quant_type_upper}", icon="⚡" + ) print_info(f"Applying {quant_type_upper} quantization...") - + # Final quantized output quantized_file = f"{model_name}.{quant_type_upper}.gguf" - + quantize_gguf( input_gguf=f16_gguf_file, output_gguf=quantized_file, quant_type=quant_type_upper, - print_output=self.verbose + print_output=self.verbose, ) - + final_file = quantized_file - + # Clean up intermediate F16 file if os.path.exists(f16_gguf_file) and f16_gguf_file != quantized_file: os.remove(f16_gguf_file) - + if self.verbose: print_success(f"Quantization complete: {quantized_file}") else: final_file = f16_gguf_file if self.verbose: 
print_info("No quantization needed (already in target format)") - + # Move to output path if different if os.path.abspath(final_file) != os.path.abspath(output_path): if self.verbose: print_info(f"Moving {final_file} → {output_path}") shutil.move(final_file, output_path) - + # Clean up dequantized model if created if is_bnb_quantized and model_to_save is not self.model: del model_to_save if torch.cuda.is_available(): torch.cuda.empty_cache() - + # Print final summary elapsed = time.time() - start_time if self.verbose: @@ -1558,55 +1734,60 @@ def _export_gguf( print_info(f"Format: GGUF {quant_type_upper}") print_info(f"Size: {format_size(file_size_bytes)}") print_info(f"Time: {format_time(elapsed)}") - + return output_path - + def _is_bnb_quantized(self) -> bool: """Check if model is BitsAndBytes quantized.""" # Check config for quantization_config - if hasattr(self.model, 'config'): - quant_config = getattr(self.model.config, 'quantization_config', None) + if hasattr(self.model, "config"): + quant_config = getattr(self.model.config, "quantization_config", None) if quant_config: # Check if it's BitsAndBytes - quant_method = getattr(quant_config, 'quant_method', None) - if quant_method in ['bitsandbytes', 'bnb']: + quant_method = getattr(quant_config, "quant_method", None) + if quant_method in ["bitsandbytes", "bnb"]: return True - if getattr(quant_config, 'load_in_4bit', False): + if getattr(quant_config, "load_in_4bit", False): return True - if getattr(quant_config, 'load_in_8bit', False): + if getattr(quant_config, "load_in_8bit", False): return True - + # Check for BNB linear layers in the model try: import bitsandbytes as bnb + for module in self.model.modules(): if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): return True except ImportError: pass - + return False - + def _dequantize_model(self) -> nn.Module: """ Dequantize a BitsAndBytes model to full precision for GGUF export. 
- + Returns: Dequantized model in float16/bfloat16 """ import gc - + # Get the model name for reloading - model_name = getattr(self.model.config, '_name_or_path', None) - + model_name = getattr(self.model.config, "_name_or_path", None) + if model_name: # Best approach: Reload model in full precision if self.verbose: print_info(f"Reloading {model_name} in full precision...") - + # Determine target dtype - target_dtype = self.config.dtype if self.config.dtype in [torch.float16, torch.bfloat16] else torch.float16 - + target_dtype = ( + self.config.dtype + if self.config.dtype in [torch.float16, torch.bfloat16] + else torch.float16 + ) + try: dequant_model = AutoModelForCausalLM.from_pretrained( model_name, @@ -1620,60 +1801,76 @@ def _dequantize_model(self) -> nn.Module: if self.verbose: print_warning(f"Failed to reload model: {e}") print_info("Attempting in-place dequantization...") - + # Fallback: In-place dequantization (less reliable but works for some models) try: import bitsandbytes as bnb - - target_dtype = self.config.dtype if self.config.dtype in [torch.float16, torch.bfloat16] else torch.float16 - + + target_dtype = ( + self.config.dtype + if self.config.dtype in [torch.float16, torch.bfloat16] + else torch.float16 + ) + # Create a copy of the model state dict with dequantized weights dequant_model = AutoModelForCausalLM.from_config( self.model.config, torch_dtype=target_dtype, ) - + # Copy and dequantize weights with torch.no_grad(): for name, module in self.model.named_modules(): if isinstance(module, bnb.nn.Linear4bit): # Dequantize 4-bit weights target_module = dict(dequant_model.named_modules()).get(name) - if target_module is not None and hasattr(target_module, 'weight'): + if target_module is not None and hasattr( + target_module, "weight" + ): # Get dequantized weight weight = module.weight - if hasattr(weight, 'dequantize'): + if hasattr(weight, "dequantize"): dequant_weight = weight.dequantize() else: # Manual dequantization for older versions dequant_weight = bnb.functional.dequantize_4bit( weight.data, weight.quant_state ) - target_module.weight.data.copy_(dequant_weight.to(target_dtype)) - + target_module.weight.data.copy_( + dequant_weight.to(target_dtype) + ) + if module.bias is not None: - target_module.bias.data.copy_(module.bias.data.to(target_dtype)) - + target_module.bias.data.copy_( + module.bias.data.to(target_dtype) + ) + elif isinstance(module, bnb.nn.Linear8bitLt): # Dequantize 8-bit weights target_module = dict(dequant_model.named_modules()).get(name) - if target_module is not None and hasattr(target_module, 'weight'): + if target_module is not None and hasattr( + target_module, "weight" + ): weight = module.weight - if hasattr(weight, 'dequantize'): + if hasattr(weight, "dequantize"): dequant_weight = weight.dequantize() else: dequant_weight = weight.data.to(target_dtype) - target_module.weight.data.copy_(dequant_weight.to(target_dtype)) - + target_module.weight.data.copy_( + dequant_weight.to(target_dtype) + ) + if module.bias is not None: - target_module.bias.data.copy_(module.bias.data.to(target_dtype)) - + target_module.bias.data.copy_( + module.bias.data.to(target_dtype) + ) + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() - + return dequant_model - + except Exception as e: raise RuntimeError( f"Failed to dequantize BitsAndBytes model: {e}\n\n" @@ -1681,45 +1878,42 @@ def _dequantize_model(self) -> nn.Module: " model = TurboModel.from_pretrained('your-model', quantize=False)\n" " model.export('gguf', quantization='Q4_K_M')" ) - - 
def _export_safetensors( - self, - output_path: str, - **kwargs - ) -> str: + + def _export_safetensors(self, output_path: str, **kwargs) -> str: """Export to safetensors format.""" os.makedirs(output_path, exist_ok=True) self.model.save_pretrained(output_path, safe_serialization=True) self.tokenizer.save_pretrained(output_path) return output_path - + def _export_onnx( self, output_path: str, quantization: Optional[str] = None, opset_version: int = 14, - **kwargs + **kwargs, ) -> str: """ Export to ONNX format with proper structure. - + Uses Optimum's ONNX exporter which properly handles LLMs like Llama. torch.onnx.export does NOT work for modern LLMs due to dynamic attention. - + Args: output_path: Output directory for ONNX files quantization: ONNX quantization type (dynamic, static, int8, avx2, avx512) opset_version: ONNX opset version (default: 14) """ from ..utils import QuantLLMProgress - + # Check for required dependencies try: from optimum.onnxruntime import ORTModelForCausalLM + HAS_OPTIMUM = True except ImportError: HAS_OPTIMUM = False - + if not HAS_OPTIMUM: # Cannot export LLMs without Optimum - torch.onnx.export doesn't work error_msg = """ @@ -1737,23 +1931,27 @@ def _export_onnx( pip install quantllm[onnx] """ print_error(error_msg) - raise ImportError("ONNX export requires: pip install onnx onnxruntime optimum[onnxruntime] onnxscript") - + raise ImportError( + "ONNX export requires: pip install onnx onnxruntime optimum[onnxruntime] onnxscript" + ) + os.makedirs(output_path, exist_ok=True) model_name = self.model.config._name_or_path - + if self.verbose: print_info("Using Optimum for ONNX export...") - + with QuantLLMProgress() as progress: task = progress.add_task("Exporting to ONNX...", total=None) - + try: # Check if model is quantized - need to export from original if self._is_bnb_quantized(): if self.verbose: - print_info("BNB quantized model detected. Exporting from original HuggingFace model...") - + print_info( + "BNB quantized model detected. Exporting from original HuggingFace model..." + ) + # Export directly from HuggingFace (not our quantized version) ort_model = ORTModelForCausalLM.from_pretrained( model_name, @@ -1764,11 +1962,11 @@ def _export_onnx( # Save model first, then convert temp_path = os.path.join(output_path, "_temp_hf") os.makedirs(temp_path, exist_ok=True) - + try: self.model.save_pretrained(temp_path, safe_serialization=True) self.tokenizer.save_pretrained(temp_path) - + ort_model = ORTModelForCausalLM.from_pretrained( temp_path, export=True, @@ -1777,82 +1975,110 @@ def _export_onnx( finally: # Clean temp shutil.rmtree(temp_path, ignore_errors=True) - + # Save ONNX model ort_model.save_pretrained(output_path) self.tokenizer.save_pretrained(output_path) - + except Exception as e: progress.update(task, completed=100) error_str = str(e) - + # Check for common issues and provide helpful messages if "onnxscript" in error_str.lower(): - print_error("Missing onnxscript package. Install with: pip install onnxscript") - raise ImportError("ONNX export requires onnxscript: pip install onnxscript") from e - elif "cannot export" in error_str.lower() or "unsupported" in error_str.lower(): - print_error(f"Model architecture may not support ONNX export: {error_str}") + print_error( + "Missing onnxscript package. 
Install with: pip install onnxscript" + ) + raise ImportError( + "ONNX export requires onnxscript: pip install onnxscript" + ) from e + elif ( + "cannot export" in error_str.lower() + or "unsupported" in error_str.lower() + ): + print_error( + f"Model architecture may not support ONNX export: {error_str}" + ) raise else: raise - + progress.update(task, completed=100) - + # Apply quantization if requested if quantization: if self.verbose: print_info(f"Applying {quantization} ONNX quantization...") self._quantize_onnx_model(output_path, quantization) - + if self.verbose: print_success(f"ONNX model exported to {output_path}") - + return output_path - + def _quantize_onnx_model(self, model_path: str, quant_type: str) -> None: """ Apply ONNX quantization. - + ONNX supports INT8 (8-bit integer) quantization only. Unlike GGUF, ONNX doesn't support 2/3/4/5/6-bit quantization. - + Args: model_path: Path to ONNX model directory quant_type: Quantization type: - Bit-based: "8", "8bit", "int8" → INT8 quantization - Platform: "avx2", "avx512", "arm64" → Platform-optimized INT8 - Type: "dynamic", "static" → Quantization method - + Note: Requests for 4-bit or other bit widths will use INT8 with a warning. """ try: from optimum.onnxruntime import ORTQuantizer from optimum.onnxruntime.configuration import AutoQuantizationConfig - + quantizer = ORTQuantizer.from_pretrained(model_path) - + # Normalize quantization type quant_lower = quant_type.lower().replace("_", "").replace("-", "") - + # Check for bit-based requests (ONNX only supports 8-bit) bit_request = None - for bit_pattern in ["2bit", "3bit", "4bit", "5bit", "6bit", "q2", "q3", "q4", "q5", "q6"]: + for bit_pattern in [ + "2bit", + "3bit", + "4bit", + "5bit", + "6bit", + "q2", + "q3", + "q4", + "q5", + "q6", + ]: if bit_pattern in quant_lower: bit_request = bit_pattern break - + if bit_request: - print_warning(f"ONNX only supports INT8 (8-bit) quantization, not {quant_type}.") + print_warning( + f"ONNX only supports INT8 (8-bit) quantization, not {quant_type}." + ) print_info("For lower bit quantization, use GGUF format instead.") print_info("Proceeding with INT8 quantization...") - + # Determine optimal config based on platform or explicit request if "avx512" in quant_lower or "vnni" in quant_lower: - qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True) + qconfig = AutoQuantizationConfig.avx512_vnni( + is_static=False, per_channel=True + ) if self.verbose: - print_info("Using AVX512 VNNI INT8 quantization (Intel Xeon/Ice Lake+)") + print_info( + "Using AVX512 VNNI INT8 quantization (Intel Xeon/Ice Lake+)" + ) elif "arm64" in quant_lower or "arm" in quant_lower: - qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=True) + qconfig = AutoQuantizationConfig.arm64( + is_static=False, per_channel=True + ) if self.verbose: print_info("Using ARM64 INT8 quantization (Apple Silicon/ARM)") elif "static" in quant_lower: @@ -1865,69 +2091,72 @@ def _quantize_onnx_model(self, model_path: str, quant_type: str) -> None: qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=True) if self.verbose: print_info("Using dynamic INT8 quantization (AVX2)") - + # Apply quantization quantizer.quantize(save_dir=model_path, quantization_config=qconfig) - + if self.verbose: print_success("ONNX INT8 quantization applied successfully") - + except ImportError: - print_warning("Optimum quantization not available. Skipping ONNX quantization.") + print_warning( + "Optimum quantization not available. Skipping ONNX quantization." 
+ ) print_info("Install with: pip install optimum[onnxruntime]") except Exception as e: print_warning(f"ONNX quantization failed: {e}") print_info("The unquantized ONNX model is still available.") - + def _export_mlx( - self, - output_path: str, - quantization: Optional[str] = None, - **kwargs + self, output_path: str, quantization: Optional[str] = None, **kwargs ) -> str: """ Export to MLX format for Apple Silicon. - + MLX supports 4-bit and 8-bit quantization only. - + Args: output_path: Output directory quantization: MLX quantization options: - "4bit", "4", "Q4", "Q4_K_M" → 4-bit quantization - "8bit", "8", "Q8" → 8-bit quantization - None → No quantization (FP16) - + Note: 2-bit, 3-bit, 5-bit, 6-bit requests will map to closest (4 or 8-bit). """ - from ..utils import QuantLLMProgress - import subprocess - import sys - # Check platform import platform + import subprocess + import sys + + from ..utils import QuantLLMProgress + if platform.system() != "Darwin" or platform.machine() != "arm64": print_warning("MLX export is optimized for Apple Silicon Macs.") - print_info("The model will be saved but may not run efficiently on this system.") - + print_info( + "The model will be saved but may not run efficiently on this system." + ) + try: import mlx + HAS_MLX = True except ImportError: HAS_MLX = False - + os.makedirs(output_path, exist_ok=True) model_name = self.model.config._name_or_path - + if HAS_MLX: try: from mlx_lm import convert - + if self.verbose: print_info("Using mlx-lm for conversion...") - + with QuantLLMProgress() as progress: task = progress.add_task("Converting to MLX...", total=None) - + # Save HF model first if quantized if self._is_bnb_quantized(): # Use original model name @@ -1937,25 +2166,31 @@ def _export_mlx( os.makedirs(source_path, exist_ok=True) self.model.save_pretrained(source_path) self.tokenizer.save_pretrained(source_path) - + # Build convert arguments convert_args = { "hf_path": source_path, "mlx_path": output_path, } - + # Parse quantization request if quantization: - quant_lower = quantization.lower().replace("_", "").replace("-", "") - + quant_lower = ( + quantization.lower().replace("_", "").replace("-", "") + ) + # MLX only supports 4-bit and 8-bit if any(x in quant_lower for x in ["2", "3"]): - print_warning(f"MLX only supports 4-bit and 8-bit, not {quantization}.") + print_warning( + f"MLX only supports 4-bit and 8-bit, not {quantization}." + ) print_info("Using 4-bit quantization (smallest available).") convert_args["quantize"] = True convert_args["q_bits"] = 4 elif any(x in quant_lower for x in ["5", "6", "7"]): - print_warning(f"MLX only supports 4-bit and 8-bit, not {quantization}.") + print_warning( + f"MLX only supports 4-bit and 8-bit, not {quantization}." + ) print_info("Using 8-bit quantization (closest available).") convert_args["quantize"] = True convert_args["q_bits"] = 8 @@ -1975,16 +2210,16 @@ def _export_mlx( convert_args["q_bits"] = 4 if self.verbose: print_info("Using 4-bit MLX quantization (default)") - + # Run conversion convert(**convert_args) - + # Clean temp if not self._is_bnb_quantized(): shutil.rmtree(source_path, ignore_errors=True) - + progress.update(task, completed=100) - + except Exception as e: print_error(f"MLX conversion failed: {e}") raise @@ -1992,11 +2227,13 @@ def _export_mlx( # Fallback: save as HF format with instructions if self.verbose: print_warning("mlx-lm not installed. 
Saving as HuggingFace format.") - print_info("To convert to MLX: pip install mlx-lm && python -m mlx_lm.convert ...") - + print_info( + "To convert to MLX: pip install mlx-lm && python -m mlx_lm.convert ..." + ) + self.model.save_pretrained(output_path) self.tokenizer.save_pretrained(output_path) - + # Create README with conversion instructions readme_path = os.path.join(output_path, "CONVERT_TO_MLX.md") with open(readme_path, "w") as f: @@ -2005,14 +2242,16 @@ def _export_mlx( f.write("To convert to MLX format on Apple Silicon:\n\n") f.write("```bash\n") f.write("pip install mlx-lm\n") - f.write(f"python -m mlx_lm.convert --hf-path {output_path} --mlx-path ./mlx_model\n") + f.write( + f"python -m mlx_lm.convert --hf-path {output_path} --mlx-path ./mlx_model\n" + ) f.write("```\n") - + if self.verbose: print_success(f"MLX model exported to {output_path}") - + return output_path - + def __repr__(self) -> str: params = self.model.num_parameters() / 1e9 return ( @@ -2025,49 +2264,52 @@ def __repr__(self) -> str: f")" ) - def optimize_inference(self, backend: str = "triton", bits: int = 4): """ Optimize model for inference using high-performance kernels. - + Args: backend: Optimization backend ("triton") bits: Quantization bits (4 or 8) """ if backend == "triton": from ..kernels.triton import TritonQuantizedLinear, is_triton_available + if not is_triton_available(): - print_warning("Triton is not available or no GPU detected. Skipping optimization.") + print_warning( + "Triton is not available or no GPU detected. Skipping optimization." + ) return - + if self.verbose: print_header("Optimizing with Triton Kernels ⚡") - + count = self._replace_with_triton(self.model, bits) - + if self.verbose: print_success(f"Optimized {count} layers with Triton fused kernels!") - + def _replace_with_triton(self, module: nn.Module, bits: int) -> int: """Recursively replace Linear layers with TritonQuantizedLinear.""" from ..kernels.triton import TritonQuantizedLinear + count = 0 for name, child in module.named_children(): if isinstance(child, nn.Linear): # Replace with Triton Linear if self.verbose: print_info(f"Quantizing {name}...") - + quantized = TritonQuantizedLinear( - child.in_features, - child.out_features, - bits=bits, + child.in_features, + child.out_features, + bits=bits, bias=child.bias is not None, - group_size=128 + group_size=128, ) quantized.to(child.weight.device) quantized.quantize_from(child) - + setattr(module, name, quantized) count += 1 else: @@ -2083,7 +2325,7 @@ def register_architecture( ) -> None: """ Register a new architecture alias and optional explicit model class. - + Example: >>> register_architecture("my-new-model", base_model_type="llama") """ @@ -2101,43 +2343,45 @@ def turbo( max_length: Optional[int] = None, device: Optional[str] = None, dtype: Optional[str] = None, + base_model_fallback: bool = True, config: Optional[Dict[str, Any]] = None, **kwargs, ) -> TurboModel: """ Load and quantize any LLM in one line. - + This is the simplest way to use QuantLLM. Everything is automatically configured based on your hardware. - + Args: model: HuggingFace model name or local path bits: Override quantization bits (default: auto) max_length: Override max sequence length (default: auto) device: Override device (default: best GPU) dtype: Override dtype (default: bf16/fp16) + base_model_fallback: Retry with resolved base model config on first-load failure config: Shared export/push config (format, quantization, push_format, etc.) 
**kwargs: Additional options passed to from_pretrained - + Returns: TurboModel ready for use - + Examples: >>> # Simplest usage - everything automatic >>> model = turbo("meta-llama/Llama-3-8B") - >>> + >>> >>> # Override quantization >>> model = turbo("mistralai/Mistral-7B", bits=4) - >>> + >>> >>> # For long context >>> model = turbo("Qwen/Qwen2-72B", max_length=32768) - >>> + >>> >>> # Generate text >>> print(model.generate("Hello, world!")) - >>> + >>> >>> # Fine-tune >>> model.finetune("my_data.json") - >>> + >>> >>> # Export >>> model.export("gguf") """ @@ -2147,6 +2391,7 @@ def turbo( max_length=max_length, device=device, dtype=dtype, + base_model_fallback=base_model_fallback, config=config, **kwargs, ) diff --git a/tests/test_architecture_fallback.py b/tests/test_architecture_fallback.py index ade044b..8371981 100644 --- a/tests/test_architecture_fallback.py +++ b/tests/test_architecture_fallback.py @@ -3,8 +3,8 @@ import transformers -from quantllm.core.turbo_model import TurboModel import quantllm.core.turbo_model as turbo_model_module +from quantllm.core.turbo_model import TurboModel class _DummySmartConfig(SimpleNamespace): @@ -155,3 +155,112 @@ def from_config(cls, *args, **kwargs): assert _FakeAutoModel.called_from_pretrained is False assert _FakeAutoModel.called_from_config is True assert loaded.model.config.model_type == "llama" + + +def test_trust_remote_code_warns_for_unregistered_architecture(monkeypatch, caplog): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: _make_smart_config(), + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="newmodel", + quantization_config=None, + ), + ) + + class _FakeAutoModel: + @staticmethod + def from_pretrained(*args, **kwargs): + if "config" in kwargs: + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + raise ValueError("Unrecognized configuration class") + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + with caplog.at_level("WARNING"): + loaded = TurboModel.from_pretrained( + "org/newmodel-7b", + quantize=False, + verbose=False, + base_model_fallback=True, + trust_remote_code=True, + ) + + assert loaded.model.config.model_type == "llama" + assert ( + "trust_remote_code=True is enabled for unregistered architecture 'newmodel'" + in caplog.text + ) + + +def test_quantization_kwargs_are_preserved_during_fallback(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + smart_config = _make_smart_config() + smart_config.bits = 4 + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: smart_config, + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="newmodel", + quantization_config=None, + ), + ) + monkeypatch.setattr( + TurboModel, + "_get_quantization_kwargs", + classmethod(lambda cls, cfg: {"quantization_config": "nf4-sentinel"}), + ) + + calls = [] + + class 
_FakeAutoModel: + @staticmethod + def from_pretrained(*args, **kwargs): + calls.append(kwargs) + if len(calls) == 1: + raise ValueError("Unrecognized configuration class") + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + loaded = TurboModel.from_pretrained( + "org/newmodel-7b", + quantize=True, + verbose=False, + base_model_fallback=True, + ) + + assert loaded.model.config.model_type == "llama" + assert len(calls) == 2 + assert calls[0]["quantization_config"] == "nf4-sentinel" + assert calls[1]["quantization_config"] == "nf4-sentinel" From 0b877a0df8adee106d1c64d662075d5dd04ab672 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 25 Apr 2026 13:29:37 +0000 Subject: [PATCH 5/6] Minimize turbo_model diff while keeping review-requested fallback updates Agent-Logs-Url: https://github.com/codewithdark-git/QuantLLM/sessions/8867f3b4-18ae-4207-b2e8-51444418c7aa Co-authored-by: codewithdark-git <144595403+codewithdark-git@users.noreply.github.com> --- quantllm/core/turbo_model.py | 1226 ++++++++++++++-------------------- 1 file changed, 501 insertions(+), 725 deletions(-) diff --git a/quantllm/core/turbo_model.py b/quantllm/core/turbo_model.py index c355fe0..ffabb37 100644 --- a/quantllm/core/turbo_model.py +++ b/quantllm/core/turbo_model.py @@ -4,38 +4,29 @@ Load, quantize, fine-tune, and export LLMs with one line each. """ -import copy import os import re import shutil import tempfile +import copy from functools import lru_cache -from typing import Any, Dict, List, Optional, Type, Union - +from typing import Optional, Dict, Any, Union, List, Type import torch import torch.nn as nn -from datasets.utils.logging import disable_progress_bar as disable_ds_progress_bar from transformers import ( AutoModelForCausalLM, AutoTokenizer, - GenerationConfig, PreTrainedModel, PreTrainedTokenizer, + GenerationConfig, ) -from transformers.utils.logging import disable_progress_bar as disable_hf_progress_bar -from ..utils import ( - QuantLLMProgress, - logger, - print_error, - print_header, - print_info, - print_success, - print_warning, -) +from .smart_config import SmartConfig from .hardware import HardwareProfiler +from ..utils import logger, print_header, print_success, print_error, print_info, print_warning, QuantLLMProgress +from transformers.utils.logging import disable_progress_bar as disable_hf_progress_bar +from datasets.utils.logging import disable_progress_bar as disable_ds_progress_bar from .memory import memory_optimized_tensor_order -from .smart_config import SmartConfig DEFAULT_CHUNKED_SHARD_SIZE = "2GB" DEFAULT_EXPORT_PUSH_CONFIG = { @@ -57,30 +48,30 @@ class TurboModel: """ High-performance LLM with the simplest possible API. 
- + Features: - One-line loading with automatic quantization - One-line fine-tuning with LoRA - One-line export to multiple formats - Automatic hardware optimization - + Example: >>> # Load any model in one line >>> model = TurboModel.from_pretrained("meta-llama/Llama-3-8B") - >>> + >>> >>> # Generate text >>> response = model.generate("Explain quantum computing") - >>> + >>> >>> # Fine-tune with your data >>> model.finetune("my_data.json", epochs=3) - >>> + >>> >>> # Export to GGUF >>> model.export("gguf", "my_model.gguf") """ - + _architecture_registry: Dict[str, str] = {} _model_class_registry: Dict[str, Type[PreTrainedModel]] = {} - + def __init__( self, model: PreTrainedModel, @@ -91,7 +82,7 @@ def __init__( ): """ Initialize TurboModel. Use from_pretrained() instead of direct init. - + Args: model: The loaded/quantized model tokenizer: Associated tokenizer @@ -116,7 +107,7 @@ def register_architecture( ) -> None: """ Register a new architecture alias and optional explicit model class. - + Args: architecture: Architecture or model type name to register base_model_type: Base model family to fall back to (e.g., "llama") @@ -125,13 +116,13 @@ def register_architecture( normalized = architecture.lower().strip() if not normalized: raise ValueError("architecture must be a non-empty string") - + if base_model_type: cls._architecture_registry[normalized] = base_model_type.lower().strip() - + if model_class is not None: cls._model_class_registry[normalized] = model_class - + @classmethod def resolve_model_type( cls, @@ -142,26 +133,26 @@ def resolve_model_type( ) -> Optional[str]: """ Resolve model type using override, registry, and default family patterns. - + If config_model_type is provided but unregistered, the original config value is returned unchanged. 
""" if model_type_override: return model_type_override.lower().strip() - + model_type = (config_model_type or "").lower().strip() if model_type: return cls._architecture_registry.get(model_type, model_type) - + name = model_name.lower() for pattern, fallback in cls._architecture_registry.items(): if cls._matches_model_name_pattern(name, pattern): return fallback - + for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items(): if cls._matches_model_name_pattern(name, pattern): return fallback - + return None @classmethod @@ -199,21 +190,15 @@ def _load_model_with_fallback( from_config_only: bool, ) -> PreTrainedModel: """Load model with architecture fallback and optional config-only mode.""" - config_model_type = ( - (getattr(hf_config, "model_type", None) or "").lower().strip() - ) - is_registered_architecture = ( - config_model_type in cls._architecture_registry - if config_model_type - else False - ) + config_model_type = (getattr(hf_config, "model_type", None) or "").lower().strip() + is_registered_architecture = config_model_type in cls._architecture_registry if config_model_type else False resolved_model_type = cls.resolve_model_type( model_name, config_model_type=getattr(hf_config, "model_type", None), model_type_override=model_type_override, ) resolved_config = hf_config - + if hf_config is not None and resolved_model_type: current_model_type = getattr(hf_config, "model_type", None) if current_model_type != resolved_model_type: @@ -232,7 +217,7 @@ def _load_model_with_fallback( config_model_type, resolved_model_type, ) - + if from_config_only: if resolved_config is None: raise ValueError( @@ -244,7 +229,7 @@ def _load_model_with_fallback( trust_remote_code=trust_remote_code, torch_dtype=model_kwargs.get("torch_dtype"), ) - + try: return AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) except Exception as primary_error: @@ -256,12 +241,10 @@ def _load_model_with_fallback( fallback_kwargs = dict(model_kwargs) fallback_kwargs["config"] = resolved_config try: - return AutoModelForCausalLM.from_pretrained( - model_name, **fallback_kwargs - ) + return AutoModelForCausalLM.from_pretrained(model_name, **fallback_kwargs) except Exception as fallback_config_error: fallback_error = fallback_config_error - + if resolved_model_type: registered_cls = cls._model_class_registry.get(resolved_model_type) if registered_cls is not None: @@ -269,18 +252,14 @@ def _load_model_with_fallback( if resolved_config is not None: class_kwargs["config"] = resolved_config try: - return registered_cls.from_pretrained( - model_name, **class_kwargs - ) + return registered_cls.from_pretrained(model_name, **class_kwargs) except Exception as fallback_registered_error: fallback_error = fallback_registered_error - - error_details = ( - f" Last fallback error: {fallback_error}" if fallback_error else "" - ) + + error_details = f" Last fallback error: {fallback_error}" if fallback_error else "" architecture_label = config_model_type or "" resolved_label = resolved_model_type or "" - + raise RuntimeError( "Failed to load model with AutoModelForCausalLM and fallback resolution.\n" f"Architecture '{architecture_label}' resolved to base model type '{resolved_label}'.\n" @@ -288,9 +267,10 @@ def _load_model_with_fallback( f"1) register_architecture('{architecture_label}', base_model_type='llama').\n" "2) Use model_type_override='llama' (or your compatible base family).\n" "3) Use from_config_only=True with a loadable config " - "(usually trust_remote_code=True)." 
+ error_details + "(usually trust_remote_code=True)." + + error_details ) from (fallback_error or primary_error) - + @classmethod def from_pretrained( cls, @@ -313,13 +293,13 @@ def from_pretrained( ) -> "TurboModel": """ Load a model with automatic optimization. - + This is the main entry point for QuantLLM. It automatically: - Detects your hardware capabilities - Chooses optimal quantization settings - Configures memory management - Enables speed optimizations - + Args: model_name: HuggingFace model name or local path bits: Override quantization bits (default: auto-detect) @@ -334,14 +314,14 @@ def from_pretrained( config_override: Dict to override any auto-detected settings config: Shared export/push config (format, quantization, push_format, etc.) verbose: Print loading progress - + Returns: TurboModel ready for inference or fine-tuning - + Example: >>> # Simplest usage - everything automatic >>> model = TurboModel.from_pretrained("meta-llama/Llama-3-8B") - >>> + >>> >>> # With specific bits >>> model = TurboModel.from_pretrained("mistral-7b", bits=4) >>> @@ -351,10 +331,10 @@ def from_pretrained( # Disable default progress bars disable_hf_progress_bar() disable_ds_progress_bar() - + if verbose: print_header(f"Loading {model_name}") - + # Auto-configure everything if verbose: logger.info("🚀 Detecting hardware and configuration...") @@ -366,9 +346,9 @@ def from_pretrained( device=device, dtype=dtype, ) - + from dataclasses import asdict - + # Apply user overrides if config_override: # Handle SmartConfig objects @@ -376,98 +356,86 @@ def from_pretrained( override_dict = asdict(config_override) else: override_dict = config_override - + for key, value in override_dict.items(): if hasattr(smart_config, key): setattr(smart_config, key, value) - + if verbose: smart_config.print_summary() - + # Load tokenizer if verbose: logger.info("📝 Loading tokenizer...") - + tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=trust_remote_code, ) - + # Ensure pad token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id - + # Load model with optimizations if verbose: - if ( - smart_config.bits != smart_config.effective_loading_bits - and smart_config.bits < 16 - ): - logger.info( - f"đŸ“Ļ Loading model ({smart_config.effective_loading_bits}-bit, for {smart_config.bits}-bit GGUF export)..." 
- ) + if smart_config.bits != smart_config.effective_loading_bits and smart_config.bits < 16: + logger.info(f"đŸ“Ļ Loading model ({smart_config.effective_loading_bits}-bit, for {smart_config.bits}-bit GGUF export)...") else: logger.info(f"đŸ“Ļ Loading model ({smart_config.bits}-bit)...") - + model_kwargs = { "trust_remote_code": trust_remote_code, "torch_dtype": smart_config.dtype, } - + hf_config = None - + # Check if model is already quantized to prevent conflicts try: from transformers import AutoConfig - - hf_config = AutoConfig.from_pretrained( - model_name, trust_remote_code=trust_remote_code - ) - + hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) + existing_quant = getattr(hf_config, "quantization_config", None) if existing_quant: allow_requantize = False - + # Allow 8-bit -> 4-bit re-quantization (if B&B) is_bnb = "BitsAndBytesConfig" in existing_quant.__class__.__name__ is_8bit = getattr(existing_quant, "load_in_8bit", False) - + if is_bnb and is_8bit and smart_config.bits == 4: allow_requantize = True if is_bnb and is_8bit and smart_config.bits == 4: allow_requantize = True if verbose: logger.info(" â„šī¸ Re-quantizing 8-bit model to 4-bit") - + if not allow_requantize: if verbose: - logger.warning( - f"âš ī¸ Model is already quantized ({existing_quant.__class__.__name__}). Disabling dynamic quantization." - ) + logger.warning(f"âš ī¸ Model is already quantized ({existing_quant.__class__.__name__}). Disabling dynamic quantization.") quantize = False - + except Exception: - pass # Ignore config loading errors, proceed with defaults + pass # Ignore config loading errors, proceed with defaults # Apply quantization if requested - if cls._should_apply_quantization( - quantize, smart_config.bits, from_config_only - ): + if cls._should_apply_quantization(quantize, smart_config.bits, from_config_only): model_kwargs.update(cls._get_quantization_kwargs(smart_config)) - + # Device map for memory management if smart_config.cpu_offload: model_kwargs["device_map"] = "auto" model_kwargs["offload_folder"] = "offload" elif torch.cuda.is_available(): model_kwargs["device_map"] = {"": smart_config.device} - + # Load the model with progress spinner with QuantLLMProgress() as p: if verbose: task = p.add_task("Downloading & Loading model...", total=None) - + model = cls._load_model_with_fallback( model_name, model_kwargs, @@ -477,21 +445,21 @@ def from_pretrained( base_model_fallback=base_model_fallback, from_config_only=from_config_only, ) - + if verbose: p.update(task, completed=100) - + # Apply additional optimizations if smart_config.gradient_checkpointing: - if hasattr(model, "gradient_checkpointing_enable"): + if hasattr(model, 'gradient_checkpointing_enable'): model.gradient_checkpointing_enable() if verbose: logger.info(" ✓ Gradient checkpointing enabled") - + # Enable Flash Attention if available if smart_config.use_flash_attention: cls._enable_flash_attention(model, verbose) - + # Compile model if beneficial if smart_config.compile_model: try: @@ -501,16 +469,16 @@ def from_pretrained( except Exception as e: if verbose: print_warning(f"torch.compile failed: {e}") - + if verbose: print_success("Model loaded successfully!") logger.info("") - + instance = cls(model, tokenizer, smart_config, export_push_config=config) instance._is_quantized = quantize and smart_config.bits < 16 - + return instance - + @classmethod def from_gguf( cls, @@ -519,46 +487,45 @@ def from_gguf( *, device: Optional[str] = None, verbose: bool = True, - **kwargs, + **kwargs ) -> 
"TurboModel": """ Load a GGUF model directly from HuggingFace Hub or local path. - + This uses transformers' native GGUF support (requires transformers>=4.36.0). - + Args: model_id: HuggingFace repo ID (e.g., "TheBloke/Llama-2-7B-GGUF") or local directory - filename: GGUF filename (e.g., "llama-2-7b.Q4_K_M.gguf"). + filename: GGUF filename (e.g., "llama-2-7b.Q4_K_M.gguf"). If None, tries to auto-find. Use list_gguf_files() to see available options. device: Override device (default: auto-detect best GPU) verbose: Print progress **kwargs: Additional args for AutoModelForCausalLM.from_pretrained - + Returns: TurboModel with loaded GGUF model - + Example: >>> # List available GGUF files in a repo >>> files = TurboModel.list_gguf_files("TheBloke/Llama-2-7B-GGUF") >>> print(files) - >>> + >>> >>> # Load specific quantization >>> model = TurboModel.from_gguf( - ... "TheBloke/Llama-2-7B-GGUF", + ... "TheBloke/Llama-2-7B-GGUF", ... filename="llama-2-7b.Q4_K_M.gguf" ... ) - >>> + >>> >>> # Generate text >>> model.generate("Hello!") """ if verbose: print_header(f"Loading GGUF: {model_id}") - + # Check for GGUF package try: import gguf - - gguf_version = getattr(gguf, "__version__", "unknown") + gguf_version = getattr(gguf, '__version__', 'unknown') if verbose: print_info(f"Using gguf version: {gguf_version}") except ImportError: @@ -567,7 +534,7 @@ def from_gguf( "Loading GGUF models requires the 'gguf' package.\n" "Please run: pip install gguf>=0.10.0" ) - + # If no filename specified, try to find one if filename is None: if verbose: @@ -576,33 +543,31 @@ def from_gguf( available_files = cls.list_gguf_files(model_id) if available_files: # Prefer Q4_K_M if available, otherwise take first - q4_files = [f for f in available_files if "q4_k_m" in f.lower()] + q4_files = [f for f in available_files if 'q4_k_m' in f.lower()] filename = q4_files[0] if q4_files else available_files[0] if verbose: - print_info( - f"Found {len(available_files)} GGUF files, using: {filename}" - ) + print_info(f"Found {len(available_files)} GGUF files, using: {filename}") except Exception as e: if verbose: print_warning(f"Could not list GGUF files: {e}") - + if verbose and filename: print_info(f"Loading: {filename}") - + smart_config = SmartConfig.detect(model_id, device=device) smart_config.quant_type = "GGUF" - + with QuantLLMProgress() as progress: if verbose: task = progress.add_task("Loading GGUF model...", total=None) - + try: model = AutoModelForCausalLM.from_pretrained( model_id, gguf_file=filename, torch_dtype=smart_config.dtype, trust_remote_code=True, - **kwargs, + **kwargs ) except ImportError as e: if "gguf" in str(e).lower(): @@ -619,7 +584,7 @@ def from_gguf( f"Available files: {available}" ) from e raise - + # Load tokenizer try: tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename) @@ -629,34 +594,32 @@ def from_gguf( except Exception: # Some GGUF repos might not have tokenizer, try base model if verbose: - print_warning( - "Could not load tokenizer from GGUF repo, using default" - ) + print_warning("Could not load tokenizer from GGUF repo, using default") tokenizer = None - + if verbose: print_success("GGUF Model loaded successfully!") - + # Print model info - if hasattr(model, "num_parameters"): + if hasattr(model, 'num_parameters'): params = model.num_parameters() / 1e9 print_info(f"Parameters: {params:.2f}B") - + instance = cls(model, tokenizer, smart_config, verbose=verbose) instance._is_quantized = True return instance - + @staticmethod def list_gguf_files(model_id: str) -> List[str]: """ 
List available GGUF files in a HuggingFace repository. - + Args: model_id: HuggingFace repo ID (e.g., "TheBloke/Llama-2-7B-GGUF") - + Returns: List of GGUF filenames available in the repository - + Example: >>> files = TurboModel.list_gguf_files("TheBloke/Llama-2-7B-GGUF") >>> print(files) @@ -664,60 +627,50 @@ def list_gguf_files(model_id: str) -> List[str]: """ try: from huggingface_hub import list_repo_files - + all_files = list_repo_files(model_id) - gguf_files = [f for f in all_files if f.endswith(".gguf")] - + gguf_files = [f for f in all_files if f.endswith('.gguf')] + # Sort by quantization quality (Q4_K_M before Q2_K, etc.) def quant_sort_key(name): name_lower = name.lower() # Higher number = better quality, listed first - if "f32" in name_lower: - return 0 - if "f16" in name_lower: - return 1 - if "q8" in name_lower: - return 2 - if "q6" in name_lower: - return 3 - if "q5_k_m" in name_lower: - return 4 - if "q5_k_s" in name_lower: - return 5 - if "q4_k_m" in name_lower: - return 6 - if "q4_k_s" in name_lower: - return 7 - if "q3_k" in name_lower: - return 8 - if "q2_k" in name_lower: - return 9 + if 'f32' in name_lower: return 0 + if 'f16' in name_lower: return 1 + if 'q8' in name_lower: return 2 + if 'q6' in name_lower: return 3 + if 'q5_k_m' in name_lower: return 4 + if 'q5_k_s' in name_lower: return 5 + if 'q4_k_m' in name_lower: return 6 + if 'q4_k_s' in name_lower: return 7 + if 'q3_k' in name_lower: return 8 + if 'q2_k' in name_lower: return 9 return 10 - + return sorted(gguf_files, key=quant_sort_key) - + except Exception as e: # If it's a local path, list directory if os.path.isdir(model_id): - return [f for f in os.listdir(model_id) if f.endswith(".gguf")] + return [f for f in os.listdir(model_id) if f.endswith('.gguf')] raise ValueError(f"Could not list GGUF files from {model_id}: {e}") @staticmethod def _get_quantization_kwargs(config: SmartConfig) -> Dict[str, Any]: """ Get kwargs for quantized model loading. - + Note: BitsAndBytes only supports 4-bit and 8-bit quantization for loading. Other bit widths (2, 3, 5, 6) are only available during GGUF export. - + For loading: - bits <= 4: Uses 4-bit NF4 quantization - - bits 5-7: Uses 8-bit quantization + - bits 5-7: Uses 8-bit quantization - bits >= 8: Uses 8-bit quantization """ try: from transformers import BitsAndBytesConfig - + # BitsAndBytes only supports 4-bit and 8-bit # Map requested bits to available options if config.bits <= 4: @@ -730,12 +683,8 @@ def _get_quantization_kwargs(config: SmartConfig) -> Dict[str, Any]: bnb_4bit_use_double_quant=True, ) if config.bits < 4: - logger.info( - f" â„šī¸ BitsAndBytes supports 4/8-bit only. Using 4-bit for requested {config.bits}-bit." - ) - logger.info( - f" Tip: Export to GGUF for Q{config.bits}_K quantization!" - ) + logger.info(f" â„šī¸ BitsAndBytes supports 4/8-bit only. Using 4-bit for requested {config.bits}-bit.") + logger.info(f" Tip: Export to GGUF for Q{config.bits}_K quantization!") else: # 5, 6, 7, 8-bit requests -> 8-bit effective_bits = 8 @@ -743,15 +692,11 @@ def _get_quantization_kwargs(config: SmartConfig) -> Dict[str, Any]: load_in_8bit=True, ) if config.bits != 8: - logger.info( - f" â„šī¸ BitsAndBytes supports 4/8-bit only. Using 8-bit for requested {config.bits}-bit." - ) - logger.info( - f" Tip: Export to GGUF for Q{config.bits}_K quantization!" - ) - + logger.info(f" â„šī¸ BitsAndBytes supports 4/8-bit only. 
Using 8-bit for requested {config.bits}-bit.") + logger.info(f" Tip: Export to GGUF for Q{config.bits}_K quantization!") + return {"quantization_config": quantization_config} - + except ImportError: logger.warning("⚠ bitsandbytes not installed, loading without quantization") return {} @@ -779,20 +724,20 @@ def _build_export_push_config(config: Optional[Dict[str, Any]]) -> Dict[str, Any resolved["push_quantization"] = resolved["quantization"] return resolved - + @staticmethod def _enable_flash_attention(model: PreTrainedModel, verbose: bool = True) -> None: """Enable Flash Attention if available.""" try: # Try to use native Flash Attention 2 - if hasattr(model, "config"): + if hasattr(model, 'config'): model.config._attn_implementation = "flash_attention_2" if verbose: logger.info(" ✓ Flash Attention 2 enabled") except Exception: if verbose: logger.warning(" ⚠ Flash Attention not available") - + def generate( self, prompt: str, @@ -809,7 +754,7 @@ def generate( ) -> str: """ Generate text from a prompt. - + Args: prompt: Input text prompt max_new_tokens: Maximum tokens to generate @@ -821,18 +766,18 @@ def generate( repetition_penalty: Penalty for repeating tokens (>1.0 = less repetition) stop_strings: List of strings that stop generation **kwargs: Additional generation parameters - + Returns: Generated text response - + Example: >>> response = model.generate("Explain quantum computing") - >>> + >>> >>> # With streaming >>> response = model.generate("Tell me a story", stream=True) """ import sys - + # Tokenize input inputs = self.tokenizer( prompt, @@ -840,27 +785,20 @@ def generate( truncation=True, max_length=self.config.max_seq_length - max_new_tokens, ) - + # Move to device inputs = {k: v.to(self.model.device) for k, v in inputs.items()} - + # Default stop strings stop_strings = stop_strings or [] - + # Streaming generation if stream: return self._generate_streaming( - inputs, - max_new_tokens, - temperature, - top_p, - top_k, - do_sample, - repetition_penalty, - stop_strings, - **kwargs, + inputs, max_new_tokens, temperature, top_p, top_k, + do_sample, repetition_penalty, stop_strings, **kwargs ) - + # Non-streaming generation with torch.inference_mode(): outputs = self.model.generate( @@ -875,19 +813,19 @@ def generate( repetition_penalty=repetition_penalty, **kwargs, ) - + # Decode, removing the prompt - generated = outputs[0][inputs["input_ids"].shape[1] :] + generated = outputs[0][inputs["input_ids"].shape[1]:] response = self.tokenizer.decode(generated, skip_special_tokens=True) - + # Check for stop strings and truncate for stop in stop_strings: if stop in response: response = response.split(stop)[0] break - + return response.strip() - + def _generate_streaming( self, inputs: Dict, @@ -902,18 +840,17 @@ def _generate_streaming( ) -> str: """Generate with streaming output.""" import sys - + try: - from threading import Thread - from transformers import TextIteratorStreamer - + from threading import Thread + streamer = TextIteratorStreamer( - self.tokenizer, + self.tokenizer, skip_prompt=True, skip_special_tokens=True, ) - + generation_kwargs = { **inputs, "max_new_tokens": max_new_tokens, @@ -927,18 +864,18 @@ def _generate_streaming( "streamer": streamer, **kwargs, } - + # Run generation in background thread thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() - + # Stream output generated_text = [] for new_text in streamer: sys.stdout.write(new_text) sys.stdout.flush() generated_text.append(new_text) - + # Check stop strings full_text = 
"".join(generated_text) should_stop = False @@ -948,17 +885,17 @@ def _generate_streaming( break if should_stop: break - + thread.join() print() # New line after streaming - + response = "".join(generated_text) for stop in stop_strings: if stop in response: response = response.split(stop)[0] - + return response.strip() - + except ImportError: # Fallback to non-streaming print("(Streaming not available, using batch generation)") @@ -968,7 +905,7 @@ def _generate_streaming( temperature=temperature, stream=False, ) - + def chat( self, messages: List[Dict[str, str]], @@ -976,14 +913,14 @@ def chat( ) -> str: """ Generate response for chat-format messages. - + Args: messages: List of {"role": "user/assistant/system", "content": "..."} **kwargs: Additional generation parameters - + Returns: Assistant's response - + Example: >>> response = model.chat([ ... {"role": "system", "content": "You are a helpful assistant."}, @@ -992,12 +929,9 @@ def chat( """ # Try to apply chat template prompt = None - + # First, try native chat template - if ( - hasattr(self.tokenizer, "chat_template") - and self.tokenizer.chat_template is not None - ): + if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template is not None: try: prompt = self.tokenizer.apply_chat_template( messages, @@ -1006,25 +940,25 @@ def chat( ) except Exception: pass - + # If no template, use a sensible default if prompt is None: # Default chat format (works for most models) parts = [] for m in messages: - role = m.get("role", "user") - content = m.get("content", "") - if role == "system": + role = m.get('role', 'user') + content = m.get('content', '') + if role == 'system': parts.append(f"System: {content}\n") - elif role == "user": + elif role == 'user': parts.append(f"User: {content}\n") - elif role == "assistant": + elif role == 'assistant': parts.append(f"Assistant: {content}\n") parts.append("Assistant:") prompt = "".join(parts) - + return self.generate(prompt, **kwargs) - + def finetune( self, data: Union[str, List[Dict[str, str]], Any], @@ -1040,7 +974,7 @@ def finetune( ) -> Dict[str, Any]: """ Fine-tune the model with LoRA. - + Args: data: Training data - file path, list of dicts, or HF dataset epochs: Number of training epochs (default: auto) @@ -1051,7 +985,7 @@ def finetune( output_dir: Where to save (default: ./output/{model_name}) hub_manager: QuantLLMHubManager instance for auto-tracking **kwargs: Additional training arguments - + Returns: Training results dictionary """ @@ -1061,40 +995,33 @@ def finetune( os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" if "WANDB_DISABLED" not in os.environ: os.environ["WANDB_DISABLED"] = "true" - + # Suppress noise import warnings - warnings.filterwarnings("ignore", module="peft") warnings.filterwarnings("ignore", category=FutureWarning) - + try: from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training - from transformers import ( - DataCollatorForLanguageModeling, - Trainer, - TrainingArguments, - ) + from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling except ImportError: - raise ImportError( - "peft is required for fine-tuning. Install with: pip install peft" - ) - + raise ImportError("peft is required for fine-tuning. Install with: pip install peft") + # 2. Prepare model for training if self._is_quantized: self.model = prepare_model_for_kbit_training(self.model) - + # Enable gradient checkpointing if configured if self.config.gradient_checkpointing: self.model.gradient_checkpointing_enable() - + # 3. 
Auto-configure LoRA r = lora_r or (16 if self.model.num_parameters() < 10e9 else 64) alpha = lora_alpha or (r * 2) - + # Determine target modules based on model type target_modules = self._get_lora_target_modules() - + lora_config = LoraConfig( r=r, lora_alpha=alpha, @@ -1103,43 +1030,37 @@ def finetune( bias="none", task_type="CAUSAL_LM", ) - + # Apply LoRA if not already applied if not self._lora_applied: self.model = get_peft_model(self.model, lora_config) self._lora_applied = True - + trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad) total = sum(p.numel() for p in self.model.parameters()) - print_info( - f"LoRA applied: {trainable:,} trainable params ({100*trainable/total:.2f}%)" - ) - + print_info(f"LoRA applied: {trainable:,} trainable params ({100*trainable/total:.2f}%)") + # 4. Load and prepare data train_dataset = self._prepare_dataset(data) - + # 5. Auto-configure training settings epochs = epochs or 3 batch_size = batch_size or self.config.batch_size learning_rate = learning_rate or 2e-4 - output_dir = ( - output_dir or f"./output/{self.model.config._name_or_path.split('/')[-1]}" - ) - + output_dir = output_dir or f"./output/{self.model.config._name_or_path.split('/')[-1]}" + # Auto-track parameters if hub_manager provided if hub_manager: - hub_manager.track_hyperparameters( - { - "epochs": epochs, - "learning_rate": learning_rate, - "batch_size": batch_size, - "lora_r": r, - "lora_alpha": alpha, - "base_model": getattr(self.config, "model_name", "unknown"), - "output_dir": output_dir, - } - ) - + hub_manager.track_hyperparameters({ + "epochs": epochs, + "learning_rate": learning_rate, + "batch_size": batch_size, + "lora_r": r, + "lora_alpha": alpha, + "base_model": getattr(self.config, "model_name", "unknown"), + "output_dir": output_dir + }) + # Training loop try: training_args = TrainingArguments( @@ -1159,138 +1080,99 @@ def finetune( torch_compile=self.config.compile_model, **kwargs, ) - + trainer = Trainer( model=self.model, args=training_args, train_dataset=train_dataset, - tokenizer=self.tokenizer, # Use new argument name - data_collator=DataCollatorForLanguageModeling( - self.tokenizer, mlm=False - ), + tokenizer=self.tokenizer, # Use new argument name + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), ) - + result = trainer.train() self._is_finetuned = True - + print_success(f"Training complete! Model saved to {output_dir}") - + return { "train_loss": result.training_loss, "epochs": epochs, "output_dir": output_dir, "learning_rate": learning_rate, "batch_size": batch_size, - "lora_r": r, + "lora_r": r } - + except Exception as e: print_error(f"Training failed: {e}") raise # Hint about OOM if "out of memory" in str(e).lower(): - print_info( - "Tip: Try reducing batch_size or enabling gradient_checkpointing in config." 
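Aside: the LoRA setup inside `finetune` reduces to a few `peft` calls. A minimal sketch with a small stand-in model; note that the target module name (`c_attn` for GPT-2) varies by architecture, as `_get_lora_target_modules` below illustrates:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # stand-in model
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,              # alpha = 2 * r, mirroring the auto-configuration above
    target_modules=["c_attn"],  # GPT-2 attention projection; llama-style models use q_proj/k_proj/...
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base, lora_cfg)
peft_model.print_trainable_parameters()  # only the LoRA adapters are trainable
```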
- ) + print_info("Tip: Try reducing batch_size or enabling gradient_checkpointing in config.") raise - + def _get_lora_target_modules(self) -> List[str]: """Get appropriate LoRA target modules for the model.""" - model_type = getattr(self.model.config, "model_type", "").lower() - + model_type = getattr(self.model.config, 'model_type', '').lower() + # Common patterns by model type target_patterns = { - "llama": [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], - "mistral": [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], - "qwen2": [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "qwen2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], "phi": ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"], - "gemma": [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + "gemma": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], } - + return target_patterns.get(model_type, ["q_proj", "v_proj"]) - + def _prepare_dataset(self, data: Union[str, List[Dict], Any]) -> Any: """Prepare dataset for training.""" from datasets import Dataset, load_dataset - + if isinstance(data, str): # Load from file - if data.endswith(".json") or data.endswith(".jsonl"): - dataset = load_dataset("json", data_files=data)["train"] + if data.endswith('.json') or data.endswith('.jsonl'): + dataset = load_dataset('json', data_files=data)['train'] else: - dataset = load_dataset(data)["train"] + dataset = load_dataset(data)['train'] elif isinstance(data, list): dataset = Dataset.from_list(data) else: dataset = data # Assume it's already a Dataset - + # Tokenize def tokenize_function(examples): # Handle different data formats - if "text" in examples: - texts = examples["text"] - elif "instruction" in examples and "output" in examples: + if 'text' in examples: + texts = examples['text'] + elif 'instruction' in examples and 'output' in examples: texts = [ f"### Instruction:\n{inst}\n\n### Response:\n{out}" - for inst, out in zip(examples["instruction"], examples["output"]) + for inst, out in zip(examples['instruction'], examples['output']) ] else: # Try to concatenate all string fields keys = [k for k in examples.keys() if isinstance(examples[k][0], str)] - texts = [ - " ".join([examples[k][i] for k in keys]) - for i in range(len(examples[keys[0]])) - ] - + texts = [' '.join([examples[k][i] for k in keys]) for i in range(len(examples[keys[0]]))] + return self.tokenizer( texts, truncation=True, max_length=self.config.max_seq_length, - padding=False, # Use DataCollator for dynamic padding + padding=False, # Use DataCollator for dynamic padding ) - + tokenized = dataset.map( tokenize_function, batched=True, remove_columns=dataset.column_names, - load_from_cache_file=False, # Avoid hash warnings + load_from_cache_file=False, # Avoid hash warnings desc="Tokenizing dataset", ) - + return tokenized - + def export( self, format: Optional[str] = None, @@ -1301,13 +1183,13 @@ def export( ) -> str: """ Export model to various formats. - + Supported formats: - "gguf": For llama.cpp, Ollama, LM Studio (Q4_K_M, Q5_K_M, etc.) 
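Aside: the instruction/output branch of `tokenize_function` above can be reproduced standalone. A short sketch with toy records and a stand-in tokenizer:

```python
from datasets import Dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # stand-in tokenizer

records = [  # toy instruction/output pairs
    {"instruction": "Say hi", "output": "Hi there!"},
    {"instruction": "Count to three", "output": "1, 2, 3"},
]
ds = Dataset.from_list(records)

def tokenize(batch):
    texts = [
        f"### Instruction:\n{inst}\n\n### Response:\n{out}"
        for inst, out in zip(batch["instruction"], batch["output"])
    ]
    # padding=False: the DataCollator handles dynamic padding at training time
    return tok(texts, truncation=True, max_length=512, padding=False)

tokenized = ds.map(tokenize, batched=True, remove_columns=ds.column_names)
print(tokenized)
```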
- "safetensors": For HuggingFace ecosystem - "onnx": For ONNX Runtime, TensorRT - "mlx": For Apple Silicon Macs - + Args: format: Target format (gguf, safetensors, onnx, mlx). Uses shared config when omitted. output_path: Output file/directory path @@ -1316,10 +1198,10 @@ def export( - ONNX: dynamic, int8 - MLX: 4bit, 8bit **kwargs: Format-specific options - + Returns: Path to exported model - + Example: >>> model.export("gguf") # Uses auto name >>> model.export("gguf", "my_model.gguf", quantization="Q4_K_M") @@ -1329,26 +1211,24 @@ def export( format = ( format if format is not None - else self.export_push_config.get( - "format", DEFAULT_EXPORT_PUSH_CONFIG["format"] - ) + else self.export_push_config.get("format", DEFAULT_EXPORT_PUSH_CONFIG["format"]) ).lower() effective_quantization = quantization if effective_quantization is None and format == "gguf": effective_quantization = self.export_push_config.get( "quantization", DEFAULT_EXPORT_PUSH_CONFIG["quantization"] ) - + # Merge LoRA if applied if self._lora_applied: if self.verbose: print_info("Merging LoRA weights before export...") self.model = self.model.merge_and_unload() self._lora_applied = False - + # Auto-generate output path if output_path is None: - model_name = self.model.config._name_or_path.split("/")[-1] + model_name = self.model.config._name_or_path.split('/')[-1] if format == "gguf": quant = effective_quantization output_path = f"{model_name}.{quant.upper()}.gguf" @@ -1360,7 +1240,7 @@ def export( output_path = f"./{model_name}-mlx/" else: output_path = f"./{model_name}-{format}/" - + exporters = { "gguf": self._export_gguf, "safetensors": self._export_safetensors, @@ -1368,16 +1248,12 @@ def export( "mlx": self._export_mlx, } if format not in exporters: - raise ValueError( - f"Unknown format: {format}. Supported: {list(exporters.keys())}" - ) - + raise ValueError(f"Unknown format: {format}. Supported: {list(exporters.keys())}") + print_header(f"Exporting to {format.upper()}") - result = exporters[format]( - output_path, quantization=effective_quantization, **kwargs - ) + result = exporters[format](output_path, quantization=effective_quantization, **kwargs) print_success(f"Exported to: {result}") - + return result def push_to_hub( @@ -1388,11 +1264,11 @@ def push_to_hub( quantization: Optional[str] = None, commit_message: str = "Upload model via QuantLLM", license: str = "apache-2.0", - **kwargs, + **kwargs ): """ Push model to HuggingFace Hub with proper model card. - + Args: repo_id: Repository ID (e.g. 
"username/model") token: HF Token @@ -1401,41 +1277,39 @@ def push_to_hub( commit_message: Commit message license: License type (default: apache-2.0) **kwargs: Arguments for export - + Supported formats: - safetensors: Standard HuggingFace format - gguf: For llama.cpp, Ollama, LM Studio - onnx: For ONNX Runtime, TensorRT - mlx: For Apple Silicon (requires macOS) - + The model card will be automatically generated with: - Proper YAML frontmatter for HuggingFace - Format-specific usage examples - "Use this model" button compatibility """ from ..hub import QuantLLMHubManager - + format_lower = ( format if format is not None - else self.export_push_config.get( - "push_format", DEFAULT_EXPORT_PUSH_CONFIG["push_format"] - ) + else self.export_push_config.get("push_format", DEFAULT_EXPORT_PUSH_CONFIG["push_format"]) ).lower() push_quantization = quantization or self.export_push_config.get( "push_quantization", DEFAULT_EXPORT_PUSH_CONFIG["push_quantization"] ) - + # Get the original base model name (full path for HuggingFace link) base_model_full = self.model.config._name_or_path - model_name = base_model_full.split("/")[-1] - + model_name = base_model_full.split('/')[-1] + print_header(f"Pushing to {repo_id}") print_info(f"Format: {format_lower.upper()}") print_info(f"Base model: {base_model_full}") - + manager = QuantLLMHubManager(repo_id=repo_id, hf_token=token) - + if format_lower == "gguf": # Export GGUF directly to staging quant_label = push_quantization or self.export_push_config.get( @@ -1443,75 +1317,65 @@ def push_to_hub( ) filename = f"{model_name}.{quant_label.upper()}.gguf" save_path = os.path.join(manager.staging_dir, filename) - - self.export( - format="gguf", output_path=save_path, quantization=quant_label, **kwargs - ) - - manager.track_hyperparameters( - { - "format": "gguf", - "quantization": quant_label.upper(), - "base_model": base_model_full, - "license": license, - } - ) + + self.export(format="gguf", output_path=save_path, quantization=quant_label, **kwargs) + + manager.track_hyperparameters({ + "format": "gguf", + "quantization": quant_label.upper(), + "base_model": base_model_full, + "license": license, + }) manager._generate_model_card(format="gguf") - + elif format_lower == "onnx": # Export to ONNX format print_info("Exporting to ONNX format...") save_path = manager.staging_dir - + self._export_onnx(save_path, quantization=push_quantization, **kwargs) - - manager.track_hyperparameters( - { - "format": "onnx", - "quantization": push_quantization, - "base_model": base_model_full, - "license": license, - } - ) + + manager.track_hyperparameters({ + "format": "onnx", + "quantization": push_quantization, + "base_model": base_model_full, + "license": license, + }) manager._generate_model_card(format="onnx") - + elif format_lower == "mlx": # Export to MLX format print_info("Exporting to MLX format...") save_path = manager.staging_dir - + self._export_mlx(save_path, quantization=push_quantization, **kwargs) - - manager.track_hyperparameters( - { - "format": "mlx", - "quantization": push_quantization, - "base_model": base_model_full, - "license": license, - } - ) + + manager.track_hyperparameters({ + "format": "mlx", + "quantization": push_quantization, + "base_model": base_model_full, + "license": license, + }) manager._generate_model_card(format="mlx") - + else: # SafeTensors or PyTorch format - manager.track_hyperparameters( - { - "format": format_lower, - "base_model": base_model_full, - "license": license, - } - ) + manager.track_hyperparameters({ + "format": format_lower, + 
"base_model": base_model_full, + "license": license, + }) manager.save_final_model(self, format=format_lower) manager._generate_model_card(format=format_lower) - + manager.push(commit_message=commit_message) - + # Alias for convenience push = push_to_hub - + def _export_gguf( - self, - output_path: str, + self, + output_path: str, quantization: Optional[str] = None, fast_mode: bool = False, chunked_conversion: bool = False, @@ -1519,19 +1383,19 @@ def _export_gguf( smart_tensor_ordering: bool = False, disk_offloading: bool = False, disk_offload_dir: Optional[str] = None, - **kwargs, + **kwargs ) -> str: """ Export to GGUF format using optimized llama.cpp converter. - + Automatically installs and configures llama.cpp tools. Handles BitsAndBytes quantized models by dequantizing first. - + Flow: 1. Save model to temp directory (dequantize if needed) 2. Convert to F16 GGUF using convert_hf_to_gguf.py 3. Quantize to target format (Q4_K_M, Q5_K_M, etc.) using llama-quantize - + Args: output_path: Output file path for GGUF quantization: Quantization type (Q4_K_M, Q5_K_M, Q8_0, etc.) @@ -1542,87 +1406,71 @@ def _export_gguf( disk_offloading: Use a dedicated temp/offload directory for intermediate artifacts disk_offload_dir: Directory used when disk_offloading=True """ + from ..quant import convert_to_gguf, quantize_gguf, ensure_llama_cpp_installed, GGUF_QUANT_TYPES + from ..utils import QuantLLMProgress, format_time, format_size import time - - from ..quant import ( - GGUF_QUANT_TYPES, - convert_to_gguf, - ensure_llama_cpp_installed, - quantize_gguf, - ) - from ..utils import QuantLLMProgress, format_size, format_time - + start_time = time.time() - + effective_shard_size = max_shard_size or ( DEFAULT_CHUNKED_SHARD_SIZE if chunked_conversion else None ) - + quant_type = quantization or self.config.quant_type or "q4_k_m" quant_type_upper = quant_type.upper() quant_type_lower = quant_type.lower() - + # Check if this is a passthrough format (f16, bf16, f32 - no quantization needed) - passthrough_types = {"f16", "f32", "bf16", "float16", "float32", "bfloat16"} + passthrough_types = {'f16', 'f32', 'bf16', 'float16', 'float32', 'bfloat16'} needs_quantization = quant_type_lower not in passthrough_types - + if self.verbose: print_info(f"Target quantization: {quant_type_upper}") if fast_mode: print_info("Fast mode enabled") if chunked_conversion: - print_info( - f"Chunked conversion enabled (max_shard_size={effective_shard_size})" - ) + print_info(f"Chunked conversion enabled (max_shard_size={effective_shard_size})") if smart_tensor_ordering: print_info("Smart tensor ordering enabled") - print_warning( - "Smart tensor ordering may temporarily materialize a full state dict in memory." - ) + print_warning("Smart tensor ordering may temporarily materialize a full state dict in memory.") if disk_offloading: - print_info( - f"Disk offloading enabled ({disk_offload_dir or 'system temp'})" - ) - + print_info(f"Disk offloading enabled ({disk_offload_dir or 'system temp'})") + # Ensure llama.cpp if self.verbose: print_info("Checking llama.cpp installation...") ensure_llama_cpp_installed() - + # Check if model is BitsAndBytes quantized and needs dequantization model_to_save = self.model is_bnb_quantized = self._is_bnb_quantized() - + if is_bnb_quantized: if self.verbose: - print_warning( - "Model is BitsAndBytes quantized. Dequantizing for GGUF export..." - ) - print_info( - "This may use significant memory. For large models, consider loading with quantize=False." 
- ) - + print_warning("Model is BitsAndBytes quantized. Dequantizing for GGUF export...") + print_info("This may use significant memory. For large models, consider loading with quantize=False.") + model_to_save = self._dequantize_model() if self.verbose: print_success("Model dequantized successfully!") - + # Determine dtype for initial conversion (always F16 for best quality) model_dtype = "f16" - + # Get model name for file naming - model_name = self.model.config._name_or_path.split("/")[-1] - + model_name = self.model.config._name_or_path.split('/')[-1] + temp_parent = disk_offload_dir if disk_offloading else None if temp_parent: os.makedirs(temp_parent, exist_ok=True) - + # Create temp dir for conversion with tempfile.TemporaryDirectory(dir=temp_parent) as temp_dir: # Step 1: Save model to temp directory if self.verbose: print_header("Step 1/3: Saving Model", icon="💾") print_info(f"Staging model to {temp_dir}...") - + with QuantLLMProgress() as progress: task = progress.add_task("Saving model weights...", total=None) save_kwargs = { @@ -1630,101 +1478,95 @@ def _export_gguf( } if effective_shard_size: save_kwargs["max_shard_size"] = effective_shard_size - + if smart_tensor_ordering: - save_kwargs["state_dict"] = memory_optimized_tensor_order( - model_to_save.state_dict() - ) - + save_kwargs["state_dict"] = memory_optimized_tensor_order(model_to_save.state_dict()) + try: model_to_save.save_pretrained(temp_dir, **save_kwargs) except Exception as e: if self.verbose: - print_warning( - f"SafeTensors save failed ({e}), using PyTorch format..." - ) + print_warning(f"SafeTensors save failed ({e}), using PyTorch format...") save_kwargs["safe_serialization"] = False model_to_save.save_pretrained(temp_dir, **save_kwargs) - + self.tokenizer.save_pretrained(temp_dir) progress.update(task, completed=100) - + if self.verbose: print_success("Model saved to staging area!") - + # Step 2: Convert to F16 GGUF if self.verbose: print_header("Step 2/3: Converting to GGUF", icon="🔄") - + # F16 intermediate file (or final if no quantization needed) if needs_quantization: f16_gguf_file = os.path.join(temp_dir, f"{model_name}.F16.gguf") else: f16_gguf_file = f"{model_name}.{quant_type_upper}.gguf" - + output_files, _ = convert_to_gguf( model_name=model_name, input_folder=temp_dir, model_dtype=model_dtype, quantization_type="f16" if needs_quantization else quant_type_lower, - print_output=self.verbose, + print_output=self.verbose ) - + if not output_files: raise RuntimeError("GGUF conversion failed to produce output file.") - + f16_file = output_files[0] - + # If conversion produced a different name, use that if os.path.exists(f16_file): f16_gguf_file = f16_file - + if self.verbose: print_success(f"F16 GGUF created: {f16_gguf_file}") - + # Step 3: Apply quantization if needed if needs_quantization: if self.verbose: - print_header( - f"Step 3/3: Quantizing to {quant_type_upper}", icon="⚡" - ) + print_header(f"Step 3/3: Quantizing to {quant_type_upper}", icon="⚡") print_info(f"Applying {quant_type_upper} quantization...") - + # Final quantized output quantized_file = f"{model_name}.{quant_type_upper}.gguf" - + quantize_gguf( input_gguf=f16_gguf_file, output_gguf=quantized_file, quant_type=quant_type_upper, - print_output=self.verbose, + print_output=self.verbose ) - + final_file = quantized_file - + # Clean up intermediate F16 file if os.path.exists(f16_gguf_file) and f16_gguf_file != quantized_file: os.remove(f16_gguf_file) - + if self.verbose: print_success(f"Quantization complete: {quantized_file}") else: 
final_file = f16_gguf_file if self.verbose: print_info("No quantization needed (already in target format)") - + # Move to output path if different if os.path.abspath(final_file) != os.path.abspath(output_path): if self.verbose: print_info(f"Moving {final_file} → {output_path}") shutil.move(final_file, output_path) - + # Clean up dequantized model if created if is_bnb_quantized and model_to_save is not self.model: del model_to_save if torch.cuda.is_available(): torch.cuda.empty_cache() - + # Print final summary elapsed = time.time() - start_time if self.verbose: @@ -1734,60 +1576,55 @@ def _export_gguf( print_info(f"Format: GGUF {quant_type_upper}") print_info(f"Size: {format_size(file_size_bytes)}") print_info(f"Time: {format_time(elapsed)}") - + return output_path - + def _is_bnb_quantized(self) -> bool: """Check if model is BitsAndBytes quantized.""" # Check config for quantization_config - if hasattr(self.model, "config"): - quant_config = getattr(self.model.config, "quantization_config", None) + if hasattr(self.model, 'config'): + quant_config = getattr(self.model.config, 'quantization_config', None) if quant_config: # Check if it's BitsAndBytes - quant_method = getattr(quant_config, "quant_method", None) - if quant_method in ["bitsandbytes", "bnb"]: + quant_method = getattr(quant_config, 'quant_method', None) + if quant_method in ['bitsandbytes', 'bnb']: return True - if getattr(quant_config, "load_in_4bit", False): + if getattr(quant_config, 'load_in_4bit', False): return True - if getattr(quant_config, "load_in_8bit", False): + if getattr(quant_config, 'load_in_8bit', False): return True - + # Check for BNB linear layers in the model try: import bitsandbytes as bnb - for module in self.model.modules(): if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): return True except ImportError: pass - + return False - + def _dequantize_model(self) -> nn.Module: """ Dequantize a BitsAndBytes model to full precision for GGUF export. 
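Aside: outside QuantLLM, the same save, convert, and quantize flow maps onto two llama.cpp commands. A rough sketch; the converter location, staged model path, and output names are illustrative:

```python
import subprocess

# Step 1 equivalent: the HF model is assumed to be staged in ./staged_model already.
# Step 2: convert the staged HF checkpoint to an F16 GGUF with llama.cpp's converter.
subprocess.run(
    ["python", "convert_hf_to_gguf.py", "./staged_model",
     "--outfile", "model.F16.gguf", "--outtype", "f16"],
    check=True,
)

# Step 3: quantize the F16 file to the target type with llama-quantize.
subprocess.run(
    ["llama-quantize", "model.F16.gguf", "model.Q4_K_M.gguf", "Q4_K_M"],
    check=True,
)
```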
- + Returns: Dequantized model in float16/bfloat16 """ import gc - + # Get the model name for reloading - model_name = getattr(self.model.config, "_name_or_path", None) - + model_name = getattr(self.model.config, '_name_or_path', None) + if model_name: # Best approach: Reload model in full precision if self.verbose: print_info(f"Reloading {model_name} in full precision...") - + # Determine target dtype - target_dtype = ( - self.config.dtype - if self.config.dtype in [torch.float16, torch.bfloat16] - else torch.float16 - ) - + target_dtype = self.config.dtype if self.config.dtype in [torch.float16, torch.bfloat16] else torch.float16 + try: dequant_model = AutoModelForCausalLM.from_pretrained( model_name, @@ -1801,76 +1638,60 @@ def _dequantize_model(self) -> nn.Module: if self.verbose: print_warning(f"Failed to reload model: {e}") print_info("Attempting in-place dequantization...") - + # Fallback: In-place dequantization (less reliable but works for some models) try: import bitsandbytes as bnb - - target_dtype = ( - self.config.dtype - if self.config.dtype in [torch.float16, torch.bfloat16] - else torch.float16 - ) - + + target_dtype = self.config.dtype if self.config.dtype in [torch.float16, torch.bfloat16] else torch.float16 + # Create a copy of the model state dict with dequantized weights dequant_model = AutoModelForCausalLM.from_config( self.model.config, torch_dtype=target_dtype, ) - + # Copy and dequantize weights with torch.no_grad(): for name, module in self.model.named_modules(): if isinstance(module, bnb.nn.Linear4bit): # Dequantize 4-bit weights target_module = dict(dequant_model.named_modules()).get(name) - if target_module is not None and hasattr( - target_module, "weight" - ): + if target_module is not None and hasattr(target_module, 'weight'): # Get dequantized weight weight = module.weight - if hasattr(weight, "dequantize"): + if hasattr(weight, 'dequantize'): dequant_weight = weight.dequantize() else: # Manual dequantization for older versions dequant_weight = bnb.functional.dequantize_4bit( weight.data, weight.quant_state ) - target_module.weight.data.copy_( - dequant_weight.to(target_dtype) - ) - + target_module.weight.data.copy_(dequant_weight.to(target_dtype)) + if module.bias is not None: - target_module.bias.data.copy_( - module.bias.data.to(target_dtype) - ) - + target_module.bias.data.copy_(module.bias.data.to(target_dtype)) + elif isinstance(module, bnb.nn.Linear8bitLt): # Dequantize 8-bit weights target_module = dict(dequant_model.named_modules()).get(name) - if target_module is not None and hasattr( - target_module, "weight" - ): + if target_module is not None and hasattr(target_module, 'weight'): weight = module.weight - if hasattr(weight, "dequantize"): + if hasattr(weight, 'dequantize'): dequant_weight = weight.dequantize() else: dequant_weight = weight.data.to(target_dtype) - target_module.weight.data.copy_( - dequant_weight.to(target_dtype) - ) - + target_module.weight.data.copy_(dequant_weight.to(target_dtype)) + if module.bias is not None: - target_module.bias.data.copy_( - module.bias.data.to(target_dtype) - ) - + target_module.bias.data.copy_(module.bias.data.to(target_dtype)) + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() - + return dequant_model - + except Exception as e: raise RuntimeError( f"Failed to dequantize BitsAndBytes model: {e}\n\n" @@ -1878,42 +1699,45 @@ def _dequantize_model(self) -> nn.Module: " model = TurboModel.from_pretrained('your-model', quantize=False)\n" " model.export('gguf', quantization='Q4_K_M')" ) - - 
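As the error hint above suggests, the dequantization fallback can be skipped entirely by loading the model unquantized before exporting. A short sketch; the model ID is hypothetical:

```python
from quantllm import TurboModel

# Load in full precision so the GGUF exporter never has to dequantize BitsAndBytes weights.
model = TurboModel.from_pretrained("org/model-7b", quantize=False)  # hypothetical model id
model.export("gguf", quantization="Q4_K_M")
```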
def _export_safetensors(self, output_path: str, **kwargs) -> str: + + def _export_safetensors( + self, + output_path: str, + **kwargs + ) -> str: """Export to safetensors format.""" os.makedirs(output_path, exist_ok=True) self.model.save_pretrained(output_path, safe_serialization=True) self.tokenizer.save_pretrained(output_path) return output_path - + def _export_onnx( self, output_path: str, quantization: Optional[str] = None, opset_version: int = 14, - **kwargs, + **kwargs ) -> str: """ Export to ONNX format with proper structure. - + Uses Optimum's ONNX exporter which properly handles LLMs like Llama. torch.onnx.export does NOT work for modern LLMs due to dynamic attention. - + Args: output_path: Output directory for ONNX files quantization: ONNX quantization type (dynamic, static, int8, avx2, avx512) opset_version: ONNX opset version (default: 14) """ from ..utils import QuantLLMProgress - + # Check for required dependencies try: from optimum.onnxruntime import ORTModelForCausalLM - HAS_OPTIMUM = True except ImportError: HAS_OPTIMUM = False - + if not HAS_OPTIMUM: # Cannot export LLMs without Optimum - torch.onnx.export doesn't work error_msg = """ @@ -1931,27 +1755,23 @@ def _export_onnx( pip install quantllm[onnx] """ print_error(error_msg) - raise ImportError( - "ONNX export requires: pip install onnx onnxruntime optimum[onnxruntime] onnxscript" - ) - + raise ImportError("ONNX export requires: pip install onnx onnxruntime optimum[onnxruntime] onnxscript") + os.makedirs(output_path, exist_ok=True) model_name = self.model.config._name_or_path - + if self.verbose: print_info("Using Optimum for ONNX export...") - + with QuantLLMProgress() as progress: task = progress.add_task("Exporting to ONNX...", total=None) - + try: # Check if model is quantized - need to export from original if self._is_bnb_quantized(): if self.verbose: - print_info( - "BNB quantized model detected. Exporting from original HuggingFace model..." - ) - + print_info("BNB quantized model detected. Exporting from original HuggingFace model...") + # Export directly from HuggingFace (not our quantized version) ort_model = ORTModelForCausalLM.from_pretrained( model_name, @@ -1962,11 +1782,11 @@ def _export_onnx( # Save model first, then convert temp_path = os.path.join(output_path, "_temp_hf") os.makedirs(temp_path, exist_ok=True) - + try: self.model.save_pretrained(temp_path, safe_serialization=True) self.tokenizer.save_pretrained(temp_path) - + ort_model = ORTModelForCausalLM.from_pretrained( temp_path, export=True, @@ -1975,110 +1795,82 @@ def _export_onnx( finally: # Clean temp shutil.rmtree(temp_path, ignore_errors=True) - + # Save ONNX model ort_model.save_pretrained(output_path) self.tokenizer.save_pretrained(output_path) - + except Exception as e: progress.update(task, completed=100) error_str = str(e) - + # Check for common issues and provide helpful messages if "onnxscript" in error_str.lower(): - print_error( - "Missing onnxscript package. Install with: pip install onnxscript" - ) - raise ImportError( - "ONNX export requires onnxscript: pip install onnxscript" - ) from e - elif ( - "cannot export" in error_str.lower() - or "unsupported" in error_str.lower() - ): - print_error( - f"Model architecture may not support ONNX export: {error_str}" - ) + print_error("Missing onnxscript package. 
Install with: pip install onnxscript") + raise ImportError("ONNX export requires onnxscript: pip install onnxscript") from e + elif "cannot export" in error_str.lower() or "unsupported" in error_str.lower(): + print_error(f"Model architecture may not support ONNX export: {error_str}") raise else: raise - + progress.update(task, completed=100) - + # Apply quantization if requested if quantization: if self.verbose: print_info(f"Applying {quantization} ONNX quantization...") self._quantize_onnx_model(output_path, quantization) - + if self.verbose: print_success(f"ONNX model exported to {output_path}") - + return output_path - + def _quantize_onnx_model(self, model_path: str, quant_type: str) -> None: """ Apply ONNX quantization. - + ONNX supports INT8 (8-bit integer) quantization only. Unlike GGUF, ONNX doesn't support 2/3/4/5/6-bit quantization. - + Args: model_path: Path to ONNX model directory quant_type: Quantization type: - Bit-based: "8", "8bit", "int8" → INT8 quantization - Platform: "avx2", "avx512", "arm64" → Platform-optimized INT8 - Type: "dynamic", "static" → Quantization method - + Note: Requests for 4-bit or other bit widths will use INT8 with a warning. """ try: from optimum.onnxruntime import ORTQuantizer from optimum.onnxruntime.configuration import AutoQuantizationConfig - + quantizer = ORTQuantizer.from_pretrained(model_path) - + # Normalize quantization type quant_lower = quant_type.lower().replace("_", "").replace("-", "") - + # Check for bit-based requests (ONNX only supports 8-bit) bit_request = None - for bit_pattern in [ - "2bit", - "3bit", - "4bit", - "5bit", - "6bit", - "q2", - "q3", - "q4", - "q5", - "q6", - ]: + for bit_pattern in ["2bit", "3bit", "4bit", "5bit", "6bit", "q2", "q3", "q4", "q5", "q6"]: if bit_pattern in quant_lower: bit_request = bit_pattern break - + if bit_request: - print_warning( - f"ONNX only supports INT8 (8-bit) quantization, not {quant_type}." - ) + print_warning(f"ONNX only supports INT8 (8-bit) quantization, not {quant_type}.") print_info("For lower bit quantization, use GGUF format instead.") print_info("Proceeding with INT8 quantization...") - + # Determine optimal config based on platform or explicit request if "avx512" in quant_lower or "vnni" in quant_lower: - qconfig = AutoQuantizationConfig.avx512_vnni( - is_static=False, per_channel=True - ) + qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True) if self.verbose: - print_info( - "Using AVX512 VNNI INT8 quantization (Intel Xeon/Ice Lake+)" - ) + print_info("Using AVX512 VNNI INT8 quantization (Intel Xeon/Ice Lake+)") elif "arm64" in quant_lower or "arm" in quant_lower: - qconfig = AutoQuantizationConfig.arm64( - is_static=False, per_channel=True - ) + qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=True) if self.verbose: print_info("Using ARM64 INT8 quantization (Apple Silicon/ARM)") elif "static" in quant_lower: @@ -2091,72 +1883,69 @@ def _quantize_onnx_model(self, model_path: str, quant_type: str) -> None: qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=True) if self.verbose: print_info("Using dynamic INT8 quantization (AVX2)") - + # Apply quantization quantizer.quantize(save_dir=model_path, quantization_config=qconfig) - + if self.verbose: print_success("ONNX INT8 quantization applied successfully") - + except ImportError: - print_warning( - "Optimum quantization not available. Skipping ONNX quantization." - ) + print_warning("Optimum quantization not available. 
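Aside: the Optimum calls used here work the same way outside the exporter. A compact sketch that exports a tiny stand-in model and applies dynamic INT8 quantization, assuming the export leaves a single ONNX graph in the output directory:

```python
from optimum.onnxruntime import ORTModelForCausalLM, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from transformers import AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # tiny stand-in model
onnx_dir = "./onnx_model"

# Export to ONNX via Optimum (plain torch.onnx.export does not handle modern LLMs).
ORTModelForCausalLM.from_pretrained(model_id, export=True).save_pretrained(onnx_dir)
AutoTokenizer.from_pretrained(model_id).save_pretrained(onnx_dir)

# Dynamic INT8 quantization in place (AVX2 config; ONNX offers no 2-6 bit options).
quantizer = ORTQuantizer.from_pretrained(onnx_dir)
qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=True)
quantizer.quantize(save_dir=onnx_dir, quantization_config=qconfig)
```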
Skipping ONNX quantization.") print_info("Install with: pip install optimum[onnxruntime]") except Exception as e: print_warning(f"ONNX quantization failed: {e}") print_info("The unquantized ONNX model is still available.") - + def _export_mlx( - self, output_path: str, quantization: Optional[str] = None, **kwargs + self, + output_path: str, + quantization: Optional[str] = None, + **kwargs ) -> str: """ Export to MLX format for Apple Silicon. - + MLX supports 4-bit and 8-bit quantization only. - + Args: output_path: Output directory quantization: MLX quantization options: - "4bit", "4", "Q4", "Q4_K_M" → 4-bit quantization - "8bit", "8", "Q8" → 8-bit quantization - None → No quantization (FP16) - + Note: 2-bit, 3-bit, 5-bit, 6-bit requests will map to closest (4 or 8-bit). """ - # Check platform - import platform + from ..utils import QuantLLMProgress import subprocess import sys - - from ..utils import QuantLLMProgress - + + # Check platform + import platform if platform.system() != "Darwin" or platform.machine() != "arm64": print_warning("MLX export is optimized for Apple Silicon Macs.") - print_info( - "The model will be saved but may not run efficiently on this system." - ) - + print_info("The model will be saved but may not run efficiently on this system.") + try: import mlx - HAS_MLX = True except ImportError: HAS_MLX = False - + os.makedirs(output_path, exist_ok=True) model_name = self.model.config._name_or_path - + if HAS_MLX: try: from mlx_lm import convert - + if self.verbose: print_info("Using mlx-lm for conversion...") - + with QuantLLMProgress() as progress: task = progress.add_task("Converting to MLX...", total=None) - + # Save HF model first if quantized if self._is_bnb_quantized(): # Use original model name @@ -2166,31 +1955,25 @@ def _export_mlx( os.makedirs(source_path, exist_ok=True) self.model.save_pretrained(source_path) self.tokenizer.save_pretrained(source_path) - + # Build convert arguments convert_args = { "hf_path": source_path, "mlx_path": output_path, } - + # Parse quantization request if quantization: - quant_lower = ( - quantization.lower().replace("_", "").replace("-", "") - ) - + quant_lower = quantization.lower().replace("_", "").replace("-", "") + # MLX only supports 4-bit and 8-bit if any(x in quant_lower for x in ["2", "3"]): - print_warning( - f"MLX only supports 4-bit and 8-bit, not {quantization}." - ) + print_warning(f"MLX only supports 4-bit and 8-bit, not {quantization}.") print_info("Using 4-bit quantization (smallest available).") convert_args["quantize"] = True convert_args["q_bits"] = 4 elif any(x in quant_lower for x in ["5", "6", "7"]): - print_warning( - f"MLX only supports 4-bit and 8-bit, not {quantization}." - ) + print_warning(f"MLX only supports 4-bit and 8-bit, not {quantization}.") print_info("Using 8-bit quantization (closest available).") convert_args["quantize"] = True convert_args["q_bits"] = 8 @@ -2210,16 +1993,16 @@ def _export_mlx( convert_args["q_bits"] = 4 if self.verbose: print_info("Using 4-bit MLX quantization (default)") - + # Run conversion convert(**convert_args) - + # Clean temp if not self._is_bnb_quantized(): shutil.rmtree(source_path, ignore_errors=True) - + progress.update(task, completed=100) - + except Exception as e: print_error(f"MLX conversion failed: {e}") raise @@ -2227,13 +2010,11 @@ def _export_mlx( # Fallback: save as HF format with instructions if self.verbose: print_warning("mlx-lm not installed. 
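Aside: the underlying mlx-lm conversion can also be run directly. A sketch assuming Apple Silicon, `pip install mlx-lm`, and a hypothetical model path:

```python
from mlx_lm import convert  # requires Apple Silicon in practice

convert(
    hf_path="org/model-7b",   # hypothetical HF repo id or local path
    mlx_path="./mlx_model",
    quantize=True,
    q_bits=4,                 # MLX only offers 4- or 8-bit quantization
)
```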
Saving as HuggingFace format.") - print_info( - "To convert to MLX: pip install mlx-lm && python -m mlx_lm.convert ..." - ) - + print_info("To convert to MLX: pip install mlx-lm && python -m mlx_lm.convert ...") + self.model.save_pretrained(output_path) self.tokenizer.save_pretrained(output_path) - + # Create README with conversion instructions readme_path = os.path.join(output_path, "CONVERT_TO_MLX.md") with open(readme_path, "w") as f: @@ -2242,16 +2023,14 @@ def _export_mlx( f.write("To convert to MLX format on Apple Silicon:\n\n") f.write("```bash\n") f.write("pip install mlx-lm\n") - f.write( - f"python -m mlx_lm.convert --hf-path {output_path} --mlx-path ./mlx_model\n" - ) + f.write(f"python -m mlx_lm.convert --hf-path {output_path} --mlx-path ./mlx_model\n") f.write("```\n") - + if self.verbose: print_success(f"MLX model exported to {output_path}") - + return output_path - + def __repr__(self) -> str: params = self.model.num_parameters() / 1e9 return ( @@ -2264,52 +2043,49 @@ def __repr__(self) -> str: f")" ) + def optimize_inference(self, backend: str = "triton", bits: int = 4): """ Optimize model for inference using high-performance kernels. - + Args: backend: Optimization backend ("triton") bits: Quantization bits (4 or 8) """ if backend == "triton": from ..kernels.triton import TritonQuantizedLinear, is_triton_available - if not is_triton_available(): - print_warning( - "Triton is not available or no GPU detected. Skipping optimization." - ) + print_warning("Triton is not available or no GPU detected. Skipping optimization.") return - + if self.verbose: print_header("Optimizing with Triton Kernels ⚡") - + count = self._replace_with_triton(self.model, bits) - + if self.verbose: print_success(f"Optimized {count} layers with Triton fused kernels!") - + def _replace_with_triton(self, module: nn.Module, bits: int) -> int: """Recursively replace Linear layers with TritonQuantizedLinear.""" from ..kernels.triton import TritonQuantizedLinear - count = 0 for name, child in module.named_children(): if isinstance(child, nn.Linear): # Replace with Triton Linear if self.verbose: print_info(f"Quantizing {name}...") - + quantized = TritonQuantizedLinear( - child.in_features, - child.out_features, - bits=bits, + child.in_features, + child.out_features, + bits=bits, bias=child.bias is not None, - group_size=128, + group_size=128 ) quantized.to(child.weight.device) quantized.quantize_from(child) - + setattr(module, name, quantized) count += 1 else: @@ -2325,7 +2101,7 @@ def register_architecture( ) -> None: """ Register a new architecture alias and optional explicit model class. - + Example: >>> register_architecture("my-new-model", base_model_type="llama") """ @@ -2349,10 +2125,10 @@ def turbo( ) -> TurboModel: """ Load and quantize any LLM in one line. - + This is the simplest way to use QuantLLM. Everything is automatically configured based on your hardware. - + Args: model: HuggingFace model name or local path bits: Override quantization bits (default: auto) @@ -2362,26 +2138,26 @@ def turbo( base_model_fallback: Retry with resolved base model config on first-load failure config: Shared export/push config (format, quantization, push_format, etc.) 
**kwargs: Additional options passed to from_pretrained - + Returns: TurboModel ready for use - + Examples: >>> # Simplest usage - everything automatic >>> model = turbo("meta-llama/Llama-3-8B") - >>> + >>> >>> # Override quantization >>> model = turbo("mistralai/Mistral-7B", bits=4) - >>> + >>> >>> # For long context >>> model = turbo("Qwen/Qwen2-72B", max_length=32768) - >>> + >>> >>> # Generate text >>> print(model.generate("Hello, world!")) - >>> + >>> >>> # Fine-tune >>> model.finetune("my_data.json") - >>> + >>> >>> # Export >>> model.export("gguf") """ From c752dfa35494b537f6c18de7782203b9acba7df3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 25 Apr 2026 13:31:00 +0000 Subject: [PATCH 6/6] Fix import order in fallback tests per review feedback Agent-Logs-Url: https://github.com/codewithdark-git/QuantLLM/sessions/8867f3b4-18ae-4207-b2e8-51444418c7aa Co-authored-by: codewithdark-git <144595403+codewithdark-git@users.noreply.github.com> --- tests/test_architecture_fallback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_architecture_fallback.py b/tests/test_architecture_fallback.py index 8371981..9382178 100644 --- a/tests/test_architecture_fallback.py +++ b/tests/test_architecture_fallback.py @@ -3,8 +3,8 @@ import transformers -import quantllm.core.turbo_model as turbo_model_module from quantllm.core.turbo_model import TurboModel +import quantllm.core.turbo_model as turbo_model_module class _DummySmartConfig(SimpleNamespace):