8 changes: 4 additions & 4 deletions src/pquant/__init__.py
@@ -5,7 +5,7 @@
 # flake8: noqa
 backend = os.getenv("KERAS_BACKEND", "tensorflow")
 if backend == "torch":
-    from . import configs, pruning_methods
+    from . import configs
     from .core.hyperparameter_optimization import (
         PQConfig,
         ap_config,
@@ -19,7 +19,7 @@
         pdp_config,
         wanda_config,
     )
-    from .core.torch import activations, layers, optimizers, quantizer
+    from .core.torch import activations, layers, optimizers, pruning_methods, quantizer
     from .core.torch.layers import (
         add_compression_layers,
         apply_final_compression,
@@ -61,7 +61,7 @@
     __all__ = _forwards

 else:
-    from . import configs, pruning_methods
+    from . import configs
     from .core.hyperparameter_optimization import (
         PQConfig,
         ap_config,
@@ -74,7 +74,7 @@
         pdp_config,
         wanda_config,
     )
-    from .core.keras import activations, layers, quantizer
+    from .core.keras import activations, layers, pruning_methods, quantizer
     from .core.keras.layers import (
         add_compression_layers,
         apply_final_compression,
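
With these four hunks, `pruning_methods` stops being a top-level subpackage: each backend branch re-exports its own implementation, so the name still resolves from the package root. A minimal sketch of the torch path, for illustration only:

import os

os.environ["KERAS_BACKEND"] = "torch"  # pquant reads this at import time

# The backend branch re-exports the module, so the top-level name still works
from pquant import pruning_methods  # resolves to pquant.core.torch.pruning_methods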
21 changes: 0 additions & 21 deletions src/pquant/core/constants.py
@@ -10,15 +10,6 @@
     PDPPruningModel,
     WandaPruningModel,
 )
-from pquant.pruning_methods.constraint_functions import (
-    EqualityConstraint,
-    GreaterThanOrEqualConstraint,
-    LessThanOrEqualConstraint,
-)
-from pquant.pruning_methods.metric_functions import (
-    StructuredSparsityMetric,
-    UnstructuredSparsityMetric,
-)

 PRUNING_MODEL_REGISTRY = {
     "cs": CSPruningModel,
@@ -53,15 +44,3 @@
 CONFIG_FILE = "config.yaml"

 N_JOBS = 1
-
-
-METRIC_REGISTRY = {
-    "UnstructuredSparsity": UnstructuredSparsityMetric,
-    "StructuredSparsity": StructuredSparsityMetric,
-}
-
-CONSTRAINT_REGISTRY = {
-    "Equality": EqualityConstraint,
-    "LessThanOrEqual": LessThanOrEqualConstraint,
-    "GreaterThanOrEqual": GreaterThanOrEqualConstraint,
-}
@@ -113,7 +113,9 @@ def call(self, weight):
         is_training = ops.logical_not(ops.logical_or(self.is_pretraining, self.is_finetuning))
         self.mask.assign(ops.where(is_training, new_binary_mask, ops.convert_to_tensor(self.mask)))

-        sparse_weight = ops.sign(weight) * ops.reshape(autosparse_prune(w_t, self.alpha), weight.shape)
+        sparse_weight = ops.sign(weight) * ops.reshape(
+            autosparse_prune(w_t, ops.convert_to_tensor(self.alpha)), weight.shape
+        )

         return ops.where(
             self.is_pretraining,
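
The functional change here is only the `ops.convert_to_tensor(self.alpha)` wrapper, which lifts a plain Python or NumPy scalar to a backend tensor before `autosparse_prune` sees it. A sketch of the pattern, with an illustrative value:

from keras import ops

alpha = 0.05  # illustrative; in the layer this is self.alpha
alpha_t = ops.convert_to_tensor(alpha)  # a backend tensor on tensorflow, torch, and jax alike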
@@ -8,7 +8,26 @@
 import keras
 from keras import ops

-from pquant.core.constants import CONSTRAINT_REGISTRY, METRIC_REGISTRY
+from pquant.core.keras.pruning_methods.constraint_functions import (
+    EqualityConstraint,
+    GreaterThanOrEqualConstraint,
+    LessThanOrEqualConstraint,
+)
+from pquant.core.keras.pruning_methods.metric_functions import (
+    StructuredSparsityMetric,
+    UnstructuredSparsityMetric,
+)
+
+METRIC_REGISTRY = {
+    "UnstructuredSparsity": UnstructuredSparsityMetric,
+    "StructuredSparsity": StructuredSparsityMetric,
+}
+
+CONSTRAINT_REGISTRY = {
+    "Equality": EqualityConstraint,
+    "LessThanOrEqual": LessThanOrEqualConstraint,
+    "GreaterThanOrEqual": GreaterThanOrEqualConstraint,
+}

 # -------------------------------------------------------------------
 # MDMM Layer
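
The registries keep their string keys, so call sites are unchanged; only the defining module moved. For example:

metric_cls = METRIC_REGISTRY["StructuredSparsity"]          # -> StructuredSparsityMetric
constraint_cls = CONSTRAINT_REGISTRY["GreaterThanOrEqual"]  # -> GreaterThanOrEqualConstraint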
@@ -28,6 +47,10 @@ def __init__(self, config, layer_type, *args, **kwargs):
         self.constraint_layer = None
         self._is_finetuning = False
         self._is_pretraining = True
+        # TEMP: cache last penalty so calculate_additional_loss() works in
+        # custom training loops via get_model_losses(). Remove once the
+        # add_loss()/model.fit path is the only supported path.
+        self._last_penalty = None

     def build(self, input_shape):
         pruning_parameters = self.config.pruning_parameters
@@ -94,8 +117,11 @@ def call(self, weight):
         self.mask.assign(ops.where(not_active, ops.convert_to_tensor(self.mask), hard_mask))

         penalty = ops.sum(self.constraint_layer(weight))
-        self.add_loss(ops.where(not_active, ops.zeros_like(penalty), penalty))
-
+        gated_penalty = ops.where(not_active, ops.zeros_like(penalty), penalty)
+        self.add_loss(gated_penalty)
+        # TEMP: cache for calculate_additional_loss() — remove with the
+        # _last_penalty attribute once custom-loop callers move to model.losses.
+        self._last_penalty = gated_penalty
         return ops.where(self.is_finetuning, weight * hard_mask, weight)

     def get_hard_mask(self, weight=None):
@@ -109,7 +135,12 @@ def get_layer_sparsity(self, weight):

     def calculate_additional_loss(self):
         # Loss is added via self.add_loss() in call() for model.fit.
-        # For custom training loops, accumulate model.losses from the last forward pass instead.
+        # TEMP: also return the cached penalty so custom training loops using
+        # get_model_losses() see the constraint term. Remove this branch (and
+        # the _last_penalty cache) once those callers switch to model.losses;
+        # then this can revert to `return 0.0`.
+        if self._last_penalty is not None:
+            return self._last_penalty
         return 0.0

     def pre_epoch_function(self, epoch, total_epochs):
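
A sketch of the custom-loop scenario the `_last_penalty` cache keeps working, assuming a Keras-style loop where `loss_fn`, `x`, and `y` are placeholders and `get_model_losses` is the existing helper named in the comments:

y_pred = model(x, training=True)  # MDMM call() runs and add_loss() records the gated penalty
task_loss = loss_fn(y, y_pred)

# Preferred path: Keras collects the penalty recorded via add_loss()
total_loss = task_loss + sum(model.losses)

# TEMP path this patch keeps alive: helpers that sum each layer's
# calculate_additional_loss() now see the same gated penalty
# total_loss = task_loss + get_model_losses(model)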
36 changes: 34 additions & 2 deletions src/pquant/core/keras/quantizer.py
@@ -1,9 +1,10 @@
 from enum import Enum

 import keras
+from hgq.quantizer import Quantizer as HGQQuantizer
+from hgq.quantizer import QuantizerConfig
 from keras import ops

-from pquant.core.quantizer_functions import create_quantizer
+from quantizers import get_fixed_quantizer


 @keras.saving.register_keras_serializable(package="PQuantML")
@@ -184,3 +185,34 @@ def get_config(self):
         if self.use_hgq:
             config.update({"quantizer": keras.saving.serialize_keras_object(self.quantizer)})
         return config
+
+
+def create_hgq_parameters_quantizer(k, i, f, overflow, round_mode, place, gamma=1e-8):
+    quantizer_config = QuantizerConfig(
+        q_type="kif", place=place, k0=k, i0=i, f0=f, overflow_mode=overflow, round_mode=round_mode, homogeneous_axis=()
+    )
+    return HGQQuantizer(config=quantizer_config)
+
+
+def create_hgq_data_quantizer(k, i, f, overflow, round_mode, gamma=1e-8):
+    quantizer_config = QuantizerConfig(
+        q_type="kif",
+        place="datalane",
+        k0=k,
+        i0=i,
+        f0=f,
+        overflow_mode=overflow,
+        round_mode=round_mode,
+        homogeneous_axis=(0,),
+    )
+    return HGQQuantizer(config=quantizer_config)
+
+
+def create_quantizer(k, i, f, overflow, round_mode, is_heterogeneous, is_data, place="datalane", gamma=1e-8):
+    if is_heterogeneous:
+        if is_data:
+            return create_hgq_data_quantizer(k, i, f, overflow, round_mode, gamma=gamma)
+        else:
+            return create_hgq_parameters_quantizer(k, i, f, overflow, round_mode, place, gamma=gamma)
+    else:
+        return get_fixed_quantizer(round_mode=round_mode, overflow_mode=overflow)
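
A hedged usage sketch of the relocated factory; the bit-widths, mode strings, and `place` value below are illustrative assumptions, not library defaults:

# Heterogeneous (HGQ) parameter quantizer; k/i/f are the keep-sign, integer,
# and fractional bits of the "kif" scheme
wq = create_quantizer(
    k=1, i=3, f=4,
    overflow="SAT", round_mode="RND",  # assumed mode strings
    is_heterogeneous=True, is_data=False, place="weight",  # "weight" is an assumed place value
)

# Homogeneous fallback: a fixed-point quantizer from the `quantizers` package;
# note that k/i/f are not forwarded in this branch
fq = create_quantizer(1, 3, 4, "SAT", "RND", is_heterogeneous=False, is_data=False)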
48 changes: 0 additions & 48 deletions src/pquant/core/quantizer_functions.py

This file was deleted.
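
The `create_quantizer` factory it provided now lives in `src/pquant/core/keras/quantizer.py` (see the hunk above), and the old `from pquant.core.quantizer_functions import create_quantizer` import was dropped there.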
