From b271a2f2a21ce634d5d505974ebab221bd88beaf Mon Sep 17 00:00:00 2001 From: Christian Berendt Date: Wed, 6 May 2026 21:13:55 +0200 Subject: [PATCH] sonic: add --refresh-host-key option to refresh SSH host keys after redeployment Adds an optional --refresh-host-key flag to all SSH-using sonic commands (load, backup, ztp, reload, reboot, reset, show). When set, existing known_hosts entries for the target host are removed before connecting, allowing the new host key to be accepted after a switch has been redeployed and the host key changed. AI-assisted: Claude Code Signed-off-by: Christian Berendt --- osism/commands/sonic.py | 62 +++++++-- tests/unit/commands/test_sonic_ssh.py | 193 ++++++++++++++++++++++++++ 2 files changed, 247 insertions(+), 8 deletions(-) create mode 100644 tests/unit/commands/test_sonic_ssh.py diff --git a/osism/commands/sonic.py b/osism/commands/sonic.py index f7205265..551e8fff 100644 --- a/osism/commands/sonic.py +++ b/osism/commands/sonic.py @@ -15,6 +15,7 @@ from osism.utils.ssh import ( cleanup_ssh_known_hosts_for_node, ensure_known_hosts_file, + remove_known_hosts_entries, KNOWN_HOSTS_PATH, ) @@ -98,7 +99,20 @@ def _get_ssh_connection_details(self, config_context, device, hostname): return ssh_host, ssh_username - def _create_ssh_connection(self, ssh_host, ssh_username): + @staticmethod + def _add_refresh_host_key_argument(parser): + """Add --refresh-host-key option for forcing SSH host key refresh.""" + parser.add_argument( + "--refresh-host-key", + dest="refresh_host_key", + action="store_true", + help=( + "Remove the existing SSH host key entry before connecting. " + "Use after a switch redeployment when the host key has changed." + ), + ) + + def _create_ssh_connection(self, ssh_host, ssh_username, refresh_host_key=False): """Create and return SSH connection""" ssh_key_path = "/ansible/secrets/id_rsa.operator" @@ -112,6 +126,17 @@ def _create_ssh_connection(self, ssh_host, ssh_username): f"Could not initialize {KNOWN_HOSTS_PATH}, continuing with AutoAddPolicy" ) + if refresh_host_key: + logger.info( + f"Refreshing SSH host key: removing known_hosts entries for {ssh_host}" + ) + try: + remove_known_hosts_entries(ssh_host, KNOWN_HOSTS_PATH) + except Exception as e: + logger.warning( + f"Failed to refresh SSH host key for {ssh_host}, continuing: {e}" + ) + ssh = paramiko.SSHClient() # Load system host keys from centralized known_hosts file try: @@ -293,6 +318,7 @@ def get_parser(self, prog_name): type=str, help="Hostname of the SONiC switch to load configuration", ) + self._add_refresh_host_key_argument(parser) return parser def take_action(self, parsed_args): @@ -329,7 +355,9 @@ def take_action(self, parsed_args): ) # Create SSH connection - ssh = self._create_ssh_connection(ssh_host, ssh_username) + ssh = self._create_ssh_connection( + ssh_host, ssh_username, parsed_args.refresh_host_key + ) if not ssh: return 1 @@ -391,6 +419,7 @@ def get_parser(self, prog_name): parser.add_argument( "hostname", type=str, help="Hostname of the SONiC switch to backup" ) + self._add_refresh_host_key_argument(parser) return parser def take_action(self, parsed_args): @@ -420,7 +449,9 @@ def take_action(self, parsed_args): ) # Create SSH connection - ssh = self._create_ssh_connection(ssh_host, ssh_username) + ssh = self._create_ssh_connection( + ssh_host, ssh_username, parsed_args.refresh_host_key + ) if not ssh: return 1 @@ -467,6 +498,7 @@ def get_parser(self, prog_name): parser.add_argument( "hostname", type=str, help="Hostname of the SONiC switch to manage ZTP" ) + self._add_refresh_host_key_argument(parser) return parser def take_action(self, parsed_args): @@ -501,7 +533,9 @@ def take_action(self, parsed_args): ) # Create SSH connection - ssh = self._create_ssh_connection(ssh_host, ssh_username) + ssh = self._create_ssh_connection( + ssh_host, ssh_username, parsed_args.refresh_host_key + ) if not ssh: return 1 @@ -555,6 +589,7 @@ def get_parser(self, prog_name): parser.add_argument( "hostname", type=str, help="Hostname of the SONiC switch to reload" ) + self._add_refresh_host_key_argument(parser) return parser def take_action(self, parsed_args): @@ -591,7 +626,9 @@ def take_action(self, parsed_args): ) # Create SSH connection - ssh = self._create_ssh_connection(ssh_host, ssh_username) + ssh = self._create_ssh_connection( + ssh_host, ssh_username, parsed_args.refresh_host_key + ) if not ssh: return 1 @@ -664,6 +701,7 @@ def get_parser(self, prog_name): parser.add_argument( "hostname", type=str, help="Hostname of the SONiC switch to reboot" ) + self._add_refresh_host_key_argument(parser) return parser def take_action(self, parsed_args): @@ -690,7 +728,9 @@ def take_action(self, parsed_args): logger.info(f"Connecting to {hostname} ({ssh_host}) to reboot SONiC switch") # Create SSH connection - ssh = self._create_ssh_connection(ssh_host, ssh_username) + ssh = self._create_ssh_connection( + ssh_host, ssh_username, parsed_args.refresh_host_key + ) if not ssh: return 1 @@ -739,6 +779,7 @@ def get_parser(self, prog_name): action="store_true", help="Force factory reset without confirmation prompt", ) + self._add_refresh_host_key_argument(parser) return parser def take_action(self, parsed_args): @@ -780,7 +821,9 @@ def take_action(self, parsed_args): ) # Create SSH connection - ssh = self._create_ssh_connection(ssh_host, ssh_username) + ssh = self._create_ssh_connection( + ssh_host, ssh_username, parsed_args.refresh_host_key + ) if not ssh: return 1 @@ -873,6 +916,7 @@ def get_parser(self, prog_name): nargs="*", help="Show command and parameters to execute (e.g., 'interfaces', 'version', 'ip route'). If not specified, executes 'show' to display available commands", ) + self._add_refresh_host_key_argument(parser) return parser def take_action(self, parsed_args): @@ -905,7 +949,9 @@ def take_action(self, parsed_args): logger.info(f"Executing command on {hostname} ({ssh_host}): {show_command}") # Create SSH connection - ssh = self._create_ssh_connection(ssh_host, ssh_username) + ssh = self._create_ssh_connection( + ssh_host, ssh_username, parsed_args.refresh_host_key + ) if not ssh: return 1 diff --git a/tests/unit/commands/test_sonic_ssh.py b/tests/unit/commands/test_sonic_ssh.py new file mode 100644 index 00000000..7b83f38b --- /dev/null +++ b/tests/unit/commands/test_sonic_ssh.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the --refresh-host-key option on SONiC SSH-using commands.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from osism.commands import sonic + +SSH_COMMAND_CLASSES = [ + sonic.Load, + sonic.Backup, + sonic.Ztp, + sonic.Reload, + sonic.Reboot, + sonic.Reset, + sonic.Show, +] + + +def _build_parser_for(cmd_cls): + """Instantiate a cliff Command and return its argparse parser.""" + cmd = cmd_cls(MagicMock(), MagicMock()) + return cmd.get_parser("test") + + +# --- Parser wiring --- + + +@pytest.mark.parametrize("cmd_cls", SSH_COMMAND_CLASSES, ids=lambda c: c.__name__) +def test_parser_registers_refresh_host_key_option(cmd_cls): + parser = _build_parser_for(cmd_cls) + actions = {a.dest: a for a in parser._actions} + assert "refresh_host_key" in actions + action = actions["refresh_host_key"] + assert "--refresh-host-key" in action.option_strings + assert action.default is False + # store_true: nargs=0, const=True + assert action.const is True + assert action.nargs == 0 + + +@pytest.mark.parametrize("cmd_cls", SSH_COMMAND_CLASSES, ids=lambda c: c.__name__) +def test_parser_default_is_false(cmd_cls): + parser = _build_parser_for(cmd_cls) + # All these commands take a hostname positional, plus Show takes nargs="*", + # plus Ztp takes a leading action. Build a minimal valid argv per command. + if cmd_cls is sonic.Ztp: + argv = ["status", "switch1"] + else: + argv = ["switch1"] + args = parser.parse_args(argv) + assert args.refresh_host_key is False + + +@pytest.mark.parametrize("cmd_cls", SSH_COMMAND_CLASSES, ids=lambda c: c.__name__) +def test_parser_sets_true_when_flag_passed(cmd_cls): + parser = _build_parser_for(cmd_cls) + if cmd_cls is sonic.Ztp: + argv = ["status", "switch1", "--refresh-host-key"] + else: + argv = ["switch1", "--refresh-host-key"] + args = parser.parse_args(argv) + assert args.refresh_host_key is True + + +# --- _create_ssh_connection behavior --- + + +class _ConcreteSonicCommand(sonic.SonicCommandBase): + """Concrete subclass so we can instantiate the abstract base in tests.""" + + def take_action(self, parsed_args): # pragma: no cover - not exercised + return 0 + + +def _make_base(): + return _ConcreteSonicCommand(MagicMock(), MagicMock()) + + +@patch("osism.commands.sonic.paramiko") +@patch("osism.commands.sonic.remove_known_hosts_entries") +@patch("osism.commands.sonic.ensure_known_hosts_file", return_value=True) +@patch("osism.commands.sonic.os.path.exists", return_value=True) +def test_create_ssh_connection_calls_remove_when_refresh_true( + _exists, _ensure, mock_remove, _paramiko +): + base = _make_base() + result = base._create_ssh_connection("10.0.0.1", "admin", refresh_host_key=True) + assert result is not None + mock_remove.assert_called_once_with("10.0.0.1", sonic.KNOWN_HOSTS_PATH) + + +@patch("osism.commands.sonic.paramiko") +@patch("osism.commands.sonic.remove_known_hosts_entries") +@patch("osism.commands.sonic.ensure_known_hosts_file", return_value=True) +@patch("osism.commands.sonic.os.path.exists", return_value=True) +def test_create_ssh_connection_skips_remove_when_refresh_false( + _exists, _ensure, mock_remove, _paramiko +): + base = _make_base() + base._create_ssh_connection("10.0.0.1", "admin", refresh_host_key=False) + mock_remove.assert_not_called() + + +@patch("osism.commands.sonic.paramiko") +@patch("osism.commands.sonic.remove_known_hosts_entries") +@patch("osism.commands.sonic.ensure_known_hosts_file", return_value=True) +@patch("osism.commands.sonic.os.path.exists", return_value=True) +def test_create_ssh_connection_default_does_not_refresh( + _exists, _ensure, mock_remove, _paramiko +): + """Default value preserves prior (non-refreshing) behavior.""" + base = _make_base() + base._create_ssh_connection("10.0.0.1", "admin") + mock_remove.assert_not_called() + + +@patch("osism.commands.sonic.paramiko") +@patch("osism.commands.sonic.remove_known_hosts_entries") +@patch("osism.commands.sonic.ensure_known_hosts_file", return_value=True) +@patch("osism.commands.sonic.os.path.exists", return_value=False) +def test_create_ssh_connection_returns_none_when_key_missing( + _exists, _ensure, mock_remove, _paramiko +): + """Missing private key short-circuits before any host-key handling.""" + base = _make_base() + result = base._create_ssh_connection("10.0.0.1", "admin", refresh_host_key=True) + assert result is None + mock_remove.assert_not_called() + + +@patch("osism.commands.sonic.paramiko") +@patch( + "osism.commands.sonic.remove_known_hosts_entries", + side_effect=PermissionError("denied"), +) +@patch("osism.commands.sonic.ensure_known_hosts_file", return_value=True) +@patch("osism.commands.sonic.os.path.exists", return_value=True) +def test_create_ssh_connection_continues_when_refresh_fails( + _exists, _ensure, _mock_remove, mock_paramiko +): + """If the host-key refresh raises, we log and proceed with the connection.""" + base = _make_base() + result = base._create_ssh_connection("10.0.0.1", "admin", refresh_host_key=True) + assert result is not None + mock_paramiko.SSHClient.return_value.connect.assert_called_once() + + +# --- take_action forwards refresh_host_key to _create_ssh_connection --- + + +def _build_parsed_args(cmd_cls, refresh_host_key): + """Build a minimal parsed_args mock for each command's take_action.""" + parsed_args = MagicMock() + parsed_args.hostname = "switch1" + parsed_args.refresh_host_key = refresh_host_key + if cmd_cls is sonic.Ztp: + parsed_args.action = "status" + if cmd_cls is sonic.Reset: + # Skip the interactive confirmation prompt. + parsed_args.force = True + if cmd_cls is sonic.Show: + parsed_args.command = [] + return parsed_args + + +@pytest.mark.parametrize("cmd_cls", SSH_COMMAND_CLASSES, ids=lambda c: c.__name__) +@pytest.mark.parametrize("refresh", [True, False], ids=["refresh", "no_refresh"]) +@patch("osism.commands.sonic.utils") +def test_take_action_forwards_refresh_host_key(mock_utils, cmd_cls, refresh): + """Each SSH-using command must forward parsed_args.refresh_host_key.""" + cmd = cmd_cls(MagicMock(), MagicMock()) + + # Stub the helpers take_action calls before _create_ssh_connection. + cmd._get_device_from_netbox = MagicMock(return_value=MagicMock()) + cmd._get_config_context = MagicMock(return_value={"management": {}}) + cmd._save_config_context = MagicMock(return_value="/tmp/cfg.json") + cmd._get_ssh_connection_details = MagicMock(return_value=("10.0.0.1", "admin")) + # Returning None short-circuits take_action right after _create_ssh_connection. + cmd._create_ssh_connection = MagicMock(return_value=None) + + parsed_args = _build_parsed_args(cmd_cls, refresh) + cmd.take_action(parsed_args) + + cmd._create_ssh_connection.assert_called_once() + args, kwargs = cmd._create_ssh_connection.call_args + # refresh_host_key may be passed positionally (3rd arg) or as a keyword. + if "refresh_host_key" in kwargs: + assert kwargs["refresh_host_key"] is refresh + else: + assert args[2] is refresh