Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 153 additions & 8 deletions src/telemetry_window_demo/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import argparse
import math
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -29,6 +30,16 @@
"source_spread_spike",
"rare_event_repeat",
)
RUN_RULE_CONFIG_FIELDS = {
"high_error_rate": frozenset(("threshold", "severity")),
"login_fail_burst": frozenset(("threshold", "severity")),
"high_severity_spike": frozenset(("threshold", "severity")),
"persistent_high_error": frozenset(
("threshold", "consecutive_windows", "severity")
),
"source_spread_spike": frozenset(("absolute_threshold", "multiplier", "severity")),
"rare_event_repeat": frozenset(("threshold", "event_types", "severity")),
}


def main() -> None:
Expand Down Expand Up @@ -293,6 +304,15 @@ def _validate_rules_config(raw_rules_config: Any) -> dict[str, Any]:
if raw_rules_config is None
else dict(_optional_mapping(raw_rules_config, "rules"))
)
allowed_rule_keys = {"cooldown_seconds", *RUN_RULE_SECTION_NAMES}
unknown_rule_keys = sorted(
str(key) for key in rules_config if key not in allowed_rule_keys
)
if unknown_rule_keys:
raise ValueError(
"Unknown config field(s) under 'rules': " + ", ".join(unknown_rule_keys)
)

rules_config["cooldown_seconds"] = _int_config_value(
rules_config.get("cooldown_seconds", 0),
"rules.cooldown_seconds",
Expand All @@ -301,18 +321,101 @@ def _validate_rules_config(raw_rules_config: Any) -> dict[str, Any]:

for rule_name in RUN_RULE_SECTION_NAMES:
if rule_name in rules_config:
rules_config[rule_name] = dict(
_optional_mapping(rules_config[rule_name], f"rules.{rule_name}")
rule_config = dict(
_optional_mapping(
rules_config[rule_name],
f"rules.{rule_name}",
)
)
rules_config[rule_name] = _validate_rule_section_config(
rule_name,
rule_config,
)

return rules_config


rare_event_repeat = rules_config.get("rare_event_repeat")
if isinstance(rare_event_repeat, dict) and "event_types" in rare_event_repeat:
rare_event_repeat["event_types"] = _string_sequence(
rare_event_repeat["event_types"],
"rules.rare_event_repeat.event_types",
def _validate_rule_section_config(
rule_name: str,
rule_config: dict[str, Any],
) -> dict[str, Any]:
allowed_fields = RUN_RULE_CONFIG_FIELDS[rule_name]
unknown_fields = sorted(
str(key) for key in rule_config if key not in allowed_fields
)
if unknown_fields:
raise ValueError(
f"Unknown config field(s) under 'rules.{rule_name}': "
+ ", ".join(unknown_fields)
)

return rules_config
if "severity" in rule_config:
rule_config["severity"] = _string_config_value(
rule_config["severity"],
f"rules.{rule_name}.severity",
)

if rule_name == "high_error_rate":
_normalize_optional_float(
rule_config,
"threshold",
"rules.high_error_rate.threshold",
minimum=0.0,
)
elif rule_name == "login_fail_burst":
_normalize_optional_int(
rule_config,
"threshold",
"rules.login_fail_burst.threshold",
minimum=1,
)
elif rule_name == "high_severity_spike":
_normalize_optional_int(
rule_config,
"threshold",
"rules.high_severity_spike.threshold",
minimum=1,
)
elif rule_name == "persistent_high_error":
_normalize_optional_float(
rule_config,
"threshold",
"rules.persistent_high_error.threshold",
minimum=0.0,
)
_normalize_optional_int(
rule_config,
"consecutive_windows",
"rules.persistent_high_error.consecutive_windows",
minimum=1,
)
elif rule_name == "source_spread_spike":
_normalize_optional_int(
rule_config,
"absolute_threshold",
"rules.source_spread_spike.absolute_threshold",
minimum=1,
)
_normalize_optional_float(
rule_config,
"multiplier",
"rules.source_spread_spike.multiplier",
minimum=1.0,
)
elif rule_name == "rare_event_repeat":
_normalize_optional_int(
rule_config,
"threshold",
"rules.rare_event_repeat.threshold",
minimum=1,
)
if "event_types" in rule_config:
rule_config["event_types"] = _string_sequence(
rule_config["event_types"],
"rules.rare_event_repeat.event_types",
)

return rule_config


def _optional_mapping(value: Any, field_name: str) -> Mapping[str, Any]:
Expand Down Expand Up @@ -349,6 +452,48 @@ def _int_config_value(value: Any, field_name: str, *, minimum: int) -> int:
return parsed


def _float_config_value(value: Any, field_name: str, *, minimum: float) -> float:
if isinstance(value, bool):
raise ValueError(f"Config field '{field_name}' must be a number.")
if isinstance(value, (int, float)):
parsed = float(value)
elif isinstance(value, str):
try:
parsed = float(value.strip())
except ValueError as exc:
raise ValueError(f"Config field '{field_name}' must be a number.") from exc
else:
raise ValueError(f"Config field '{field_name}' must be a number.")

if not math.isfinite(parsed):
raise ValueError(f"Config field '{field_name}' must be a finite number.")
if parsed < minimum:
raise ValueError(f"Config field '{field_name}' must be at least {minimum:g}.")
return parsed


def _normalize_optional_int(
config: dict[str, Any],
key: str,
field_name: str,
*,
minimum: int,
) -> None:
if key in config:
config[key] = _int_config_value(config[key], field_name, minimum=minimum)


def _normalize_optional_float(
config: dict[str, Any],
key: str,
field_name: str,
*,
minimum: float,
) -> None:
if key in config:
config[key] = _float_config_value(config[key], field_name, minimum=minimum)


def _optional_string_sequence(value: Any, field_name: str) -> list[str] | None:
if value is None:
return None
Expand Down
54 changes: 54 additions & 0 deletions tests/test_run_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,57 @@ def test_run_config_rejects_string_rare_event_types(tmp_path) -> None:

with pytest.raises(ValueError, match="rules.rare_event_repeat.event_types"):
run_command(Namespace(config=str(config_path)))


def test_run_config_rejects_unknown_rule_name(tmp_path) -> None:
config = _base_config(tmp_path)
config["rules"]["high_error_rates"] = {"threshold": 0.30}
config_path = _write_config(tmp_path, config)

with pytest.raises(ValueError, match="high_error_rates"):
run_command(Namespace(config=str(config_path)))


def test_run_config_rejects_unknown_rule_field(tmp_path) -> None:
config = _base_config(tmp_path)
config["rules"]["high_error_rate"]["thresholds"] = 0.30
config_path = _write_config(tmp_path, config)

with pytest.raises(ValueError, match="rules.high_error_rate"):
run_command(Namespace(config=str(config_path)))


def test_run_config_rejects_boolean_rule_threshold(tmp_path) -> None:
config = _base_config(tmp_path)
config["rules"]["high_error_rate"]["threshold"] = True
config_path = _write_config(tmp_path, config)

with pytest.raises(ValueError, match="rules.high_error_rate.threshold"):
run_command(Namespace(config=str(config_path)))


def test_run_config_rejects_non_positive_count_threshold(tmp_path) -> None:
config = _base_config(tmp_path)
config["rules"]["login_fail_burst"]["threshold"] = 0
config_path = _write_config(tmp_path, config)

with pytest.raises(ValueError, match="rules.login_fail_burst.threshold"):
run_command(Namespace(config=str(config_path)))


def test_run_config_rejects_source_spread_multiplier_below_one(tmp_path) -> None:
config = _base_config(tmp_path)
config["rules"]["source_spread_spike"]["multiplier"] = 0.5
config_path = _write_config(tmp_path, config)

with pytest.raises(ValueError, match="rules.source_spread_spike.multiplier"):
run_command(Namespace(config=str(config_path)))


def test_run_config_rejects_empty_rule_severity(tmp_path) -> None:
config = _base_config(tmp_path)
config["rules"]["persistent_high_error"]["severity"] = ""
config_path = _write_config(tmp_path, config)

with pytest.raises(ValueError, match="rules.persistent_high_error.severity"):
run_command(Namespace(config=str(config_path)))
Loading