Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
490 changes: 254 additions & 236 deletions docs/notebooks/structural_reliability.ipynb

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions panel/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -203,34 +203,36 @@
<fast-text-field id="search-input" placeholder="search" onInput="hideCards(event.target.value)"></fast-text-field>
</section>

<section id="cards">
<section id="cards">
<ul class="cards-grid">
<!-- Sampling card moved to first position -->
<li class="card">
<a class="card-link" href="./simdec_app.html" id="simdec_app">
<a class="card-link" href="./sampling.html" id="sampling">
<fast-card class="gallery-item">
<object data="_static/thumbnails/simdec_app.png" type="image/png">
<object data="_static/thumbnails/sampling.png" type="image/png">
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
</svg>
</object>
<div class="card-content">
<h2 class="card-header">SimDec App</h2>
<h2 class="card-header">Sampling</h2>
</div>
</fast-card>
</a>
</li>
<!-- SimDec App card moved to second position -->
<li class="card">
<a class="card-link" href="./sampling.html" id="sampling">
<a class="card-link" href="./simdec_app.html" id="simdec_app">
<fast-card class="gallery-item">
<object data="_static/thumbnails/sampling.png" type="image/png">
<object data="_static/thumbnails/simdec_app.png" type="image/png">
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
</svg>
</object>
<div class="card-content">
<h2 class="card-header">Sampling</h2>
<h2 class="card-header">SimDec App</h2>
</div>
</fast-card>
</a>
Expand Down
7 changes: 5 additions & 2 deletions panel/simdec_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,11 @@ def filtered_si(sensitivity_indices_table, input_names):


def explained_variance_80(sensitivity_indices_table):
si = sensitivity_indices_table.value["Indices"]
pos_80 = bisect.bisect_right(np.cumsum(si), 0.8)
df = sensitivity_indices_table.value
df = df[df["Inputs"] != "Sum of Indices"]
si = df["Indices"].values
target = 0.8 * np.sum(si)
pos_80 = bisect.bisect_right(np.cumsum(si), target)

# pos_80 = max(2, pos_80)
# pos_80 = min(len(si), pos_80)
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ dashboard = [
"cryptography",
]

ipython = [
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but this could have been more descriptive.

"ipython"
]

test = [
"pytest",
"pytest-cov",
Expand All @@ -55,7 +59,7 @@ doc = [
]

dev = [
"simdec[doc,test,dashboard]",
"simdec[doc,test,dashboard, ipython]",
"watchfiles",
"pre-commit",
]
Expand Down
2 changes: 2 additions & 0 deletions src/simdec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from simdec.decomposition import *
from simdec.sensitivity_indices import *
from simdec.visualization import *
from simdec.heterogeneity_indices import *
Comment on lines 2 to +5
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: these imports should be ordered alphabetically.

Suggested change
from simdec.decomposition import *
from simdec.sensitivity_indices import *
from simdec.visualization import *
from simdec.heterogeneity_indices import *
from simdec.decomposition import *
from simdec.heterogeneity_indices import *
from simdec.sensitivity_indices import *
from simdec.visualization import *


__all__ = [
"sensitivity_indices",
Expand All @@ -11,4 +12,5 @@
"two_output_visualization",
"tableau",
"palette",
"heterogeneity_indices",
]
239 changes: 239 additions & 0 deletions src/simdec/heterogeneity_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
from dataclasses import dataclass
import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import simdec as sd

logger = logging.getLogger(__name__)

__all__ = ["heterogeneity_indices", "plot_heterogeneity"]


@dataclass
class HeterogeneityResult:
    """Container for the values returned by :func:`heterogeneity_indices`."""

    # One row per input with its overall sensitivity index and its
    # heterogeneity score, followed by a "SUM / TOTAL" row.
    summary: pd.DataFrame
    # Sensitivity indices of each input (rows) within each retained
    # region of the split variable (columns).
    regional_profiles: pd.DataFrame
    # Name of the variable the data was split on.
    split_name: str


def heterogeneity_indices(
    output: pd.Series,
    inputs: pd.DataFrame,
    split_variable: str | pd.Series,
    n_subdivisions: int | None = None,
    plot: bool = False,
) -> HeterogeneityResult:
    """Heterogeneity indices.

    Compute sensitivity-based heterogeneity across subdivisions
    of a variable.

    The data is partitioned into regions of ``split_variable``, sensitivity
    indices are computed independently inside each region, and the spread of
    each input's index across regions is reported as its heterogeneity.

    Parameters
    ----------
    output : pd.Series
        Model output vector.
    inputs : pd.DataFrame
        Input/feature matrix.
    split_variable : str or pd.Series
        Variable to split on. If string, must be a column in 'inputs'.
    n_subdivisions : int, optional
        Number of regions for continuous variables. Defaults to 4.
        Ignored when the split variable is treated as categorical
        (categorical, object, string or boolean dtype, or at most two
        distinct non-null values), in which case each category is a region.
    plot : bool, default False
        If True, displays a stacked bar chart of regional sensitivity profiles
        by calling :func:`plot_heterogeneity`. The chart shows variance
        contributions of each input across subdivisions of ``split_variable``,
        ranked by global sensitivity indices. To capture the returned
        ``matplotlib.axes.Axes`` object, call :func:`plot_heterogeneity`
        directly on the result instead.

    Returns
    -------
    res : HeterogeneityResult
        An object with attributes:

        summary : DataFrame
            A summary of calculated heterogeneity indices.
        regional_profiles : DataFrame
            Regional sensitivity indices for each input across subdivisions.
        split_name : str
            The name of the variable used to split the data.

    Raises
    ------
    ValueError
        If ``split_variable`` is a string that is not a column of ``inputs``,
        if ``output``, ``inputs`` and ``split_variable`` do not have the same
        number of samples, if quantile binning fails, or if fewer than two
        regions yield usable sensitivity indices.

    Notes
    -----
    A region is skipped (and reported through the module logger at INFO
    level) when it holds fewer than 10 samples, when the output variance
    inside it is below ``1e-12``, when the sensitivity computation raises,
    or when it returns non-finite values.

    The heterogeneity of an input is twice the population standard deviation
    of its regional sensitivity indices. The ``"SUM / TOTAL"`` row of
    ``summary`` holds the sum of the overall indices and the mean of the
    heterogeneity scores.

    """
    # Positional indices everywhere so that boolean masks built from the
    # split variable line up with the rows of X and y.
    y = pd.Series(output).reset_index(drop=True)
    X = pd.DataFrame(inputs).reset_index(drop=True)

    if isinstance(split_variable, str):
        if split_variable not in X.columns:
            raise ValueError(f"'{split_variable}' not found in inputs.")
        z = X[split_variable].reset_index(drop=True)
        split_name = split_variable
    else:
        z = pd.Series(split_variable).reset_index(drop=True)
        # An unnamed Series has ``name=None``; fall back to a readable label
        # instead of producing titles such as "across None".
        name = getattr(split_variable, "name", None)
        split_name = "split_variable" if name is None else str(name)

    # Masks are applied positionally, so mismatched lengths would silently
    # pair samples with the wrong region.
    if not (len(y) == len(X) == len(z)):
        raise ValueError(
            "output, inputs and split_variable must have the same number of "
            f"samples; got {len(y)}, {len(X)} and {len(z)}."
        )

    n_unique = len(z.dropna().unique())

    # Determine if variable is categorical/binary
    is_categorical = (
        isinstance(z.dtype, pd.CategoricalDtype)
        or pd.api.types.is_object_dtype(z)
        or pd.api.types.is_string_dtype(z)
        or pd.api.types.is_bool_dtype(z)
        or n_unique <= 2
    )

    if is_categorical:
        regions = z.astype("category")
    else:
        q = n_subdivisions if n_subdivisions is not None else 4
        try:
            regions = pd.qcut(z, q=q, duplicates="drop")
        except ValueError as e:
            raise ValueError(
                f"Failed to bin '{split_name}' into {q} quantiles: {e}"
            ) from e

    regional_profiles = []
    skipped = []  # (region, sample count, reason) for every rejected region

    for region in regions.cat.categories:
        mask = regions == region
        n_in_region = int(mask.sum())

        if n_in_region < 10:
            # Need enough samples for meaningful sensitivity indices
            skipped.append((region, n_in_region, "too few samples (< 10)"))
            continue

        X_sub = X.loc[mask]
        y_sub = y.loc[mask]

        # Skip if output has zero or near-zero variance in this region
        if y_sub.var() < 1e-12:
            skipped.append((region, n_in_region, "output variance ≈ 0"))
            continue

        try:
            res = sd.sensitivity_indices(inputs=X_sub, output=y_sub)
            si_vals = np.asarray(res.si).ravel()

            # Guard against NaN/Inf from degenerate sensitivity computation
            if not np.all(np.isfinite(si_vals)):
                skipped.append((region, n_in_region, "non-finite SI values"))
                continue

            si_region = pd.Series(si_vals, index=X.columns, name=region)
            regional_profiles.append(si_region)

        except Exception as e:
            # Deliberately broad: one degenerate region must not abort the
            # whole analysis. The failure is recorded and reported below.
            skipped.append((region, n_in_region, f"exception: {e}"))
            continue

    if skipped:
        logger.info("Skipped %d region(s) of '%s':", len(skipped), split_name)
        for reg, n, reason in skipped:
            logger.info("  - region=%r, n=%d, reason=%s", reg, n, reason)

    if len(regional_profiles) < 2:
        total_regions = len(regions.cat.categories)
        valid = len(regional_profiles)
        # Build the list separately: calling ``.join`` directly on adjacent
        # string literals would use the whole concatenated header as the
        # separator.
        if skipped:
            skipped_report = "\n".join(
                f"  {r!r}: n={n}, {reason}" for r, n, reason in skipped
            )
        else:
            skipped_report = "  (none)"
        raise ValueError(
            "Not enough valid subdivisions to compute heterogeneity: "
            f"{valid}/{total_regions} regions passed all checks for "
            f"'{split_name}'.\n"
            f"Skipped regions:\n{skipped_report}"
            "\n\nTry: (1) reducing n_subdivisions, "
            "(2) using a different split_variable, or "
            "(3) ensuring more samples per region."
        )

    # Rows are inputs, columns are the retained regions.
    regional_si = pd.concat(regional_profiles, axis=1)

    res_global = sd.sensitivity_indices(inputs=X, output=y)
    overall_si = pd.Series(
        np.asarray(res_global.si).ravel(),
        index=X.columns,
        name="Overall_SI",
    )

    # Heterogeneity = 2 × population std dev across regions
    hetero_scores = 2 * regional_si.std(axis=1, ddof=0)
    total_hetero = hetero_scores.mean()

    hetero_col_name = f"Heterogeneity (across {split_name})"
    summary = pd.DataFrame(
        {"Overall_SI": overall_si, hetero_col_name: hetero_scores}
    ).sort_values(by=hetero_col_name, ascending=False)
    summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]

    result = HeterogeneityResult(summary, regional_si, split_name)

    if plot:
        plot_heterogeneity(result)

    return result


def plot_heterogeneity(result: HeterogeneityResult, ax: plt.Axes = None) -> plt.Axes:
    """Draw the regional sensitivity profiles as a stacked bar chart.

    One bar is drawn per region of the split variable. Each bar stacks the
    regional sensitivity index of every input, with the inputs ordered by
    decreasing overall sensitivity index.

    Parameters
    ----------
    result : HeterogeneityResult
        The result object from heterogeneity_indices.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted, a new 10 x 6 inch figure is created.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes holding the chart.

    """
    label = result.split_name
    profiles = result.regional_profiles
    table = result.summary

    # Rank the inputs by overall SI, leaving out the aggregate summary row.
    input_rows = table.index[table.index != "SUM / TOTAL"]
    ranked_inputs = (
        table.loc[input_rows].sort_values(by="Overall_SI", ascending=False).index
    )

    # One colour per input, spread evenly over the "terrain" colormap.
    terrain = plt.colormaps["terrain"]
    palette = list(map(terrain, np.linspace(0.05, 0.95, len(profiles.index))))

    # Transpose so that regions lie on the x-axis and inputs are stacked.
    stacked = profiles.loc[ranked_inputs].T

    if ax is None:
        ax = plt.subplots(figsize=(10, 6))[1]

    stacked.plot(
        kind="bar",
        stacked=True,
        ax=ax,
        color=palette,
        edgecolor="white",
        width=0.8,
    )

    ax.set_title(f"Sensitivity Profiles across {label}", fontsize=14)
    ax.set_xlabel(f"Regions of {label}", fontsize=12)
    ax.set_ylabel("Variance Contribution", fontsize=12)

    # Keep the legend outside the plotting area, to the right of the axes.
    ax.legend(
        title="Inputs (Ranked by Global SI)",
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
    )

    ax.grid(axis="y", linestyle="--", alpha=0.7)
    ax.tick_params(axis="x", labelrotation=45)

    # Layout adjustment is skipped when the active backend is "agg".
    backend = plt.get_backend().lower()
    if backend != "agg":
        plt.tight_layout()

    return ax
29 changes: 24 additions & 5 deletions src/simdec/sensitivity_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class SensitivityAnalysisResult:


def sensitivity_indices(
inputs: pd.DataFrame | np.ndarray, output: pd.DataFrame | np.ndarray
inputs: pd.DataFrame | np.ndarray,
output: pd.DataFrame | np.ndarray,
print_indices: bool = False,
) -> SensitivityAnalysisResult:
"""Sensitivity indices.

Expand All @@ -50,6 +52,8 @@ def sensitivity_indices(
Input variables.
output : ndarray or DataFrame of shape (n_runs, 1)
Target variable.
print_indices : bool, default False
If True, displays computed indices.

Returns
-------
Expand Down Expand Up @@ -97,11 +101,18 @@ def sensitivity_indices(
"""
# Handle inputs conversion
if isinstance(inputs, pd.DataFrame):
cat_columns = inputs.select_dtypes(["category", "O"]).columns
inputs[cat_columns] = inputs[cat_columns].apply(
lambda x: x.astype("category").cat.codes
)
var_names = inputs.columns.tolist()
cat_cols = inputs.select_dtypes(include=["category", "O", "string"]).columns
if not cat_cols.empty:
inputs = inputs.copy() # Avoid SettingWithCopyWarning
inputs[cat_cols] = inputs[cat_cols].apply(
lambda x: x.astype("category").cat.codes
)
inputs = inputs.to_numpy()
else:
inputs = np.asarray(inputs)
# Fallback names if it's just a numpy array
var_names = [f"x{i}" for i in range(inputs.shape[1])]

# Handle output conversion first, then flatten
if isinstance(output, (pd.DataFrame, pd.Series)):
Expand Down Expand Up @@ -181,4 +192,12 @@ def sensitivity_indices(
for k in range(n_factors):
si[k] = foe[k] + (soe[:, k].sum() / 2)

if print_indices:
df_foe = pd.DataFrame(foe, index=var_names, columns=["First-order effect"])
df_soe = pd.DataFrame(soe, index=var_names, columns=var_names)
df_si = pd.DataFrame(si, index=var_names, columns=["Combined effect"])

df_indices = pd.concat([df_foe, df_soe, df_si], axis=1)
print(f"\n{df_indices}\n")

return SensitivityAnalysisResult(si, foe, soe)
Loading