Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eevel committed Oct 25, 2022
1 parent 50cb225 commit e1745b3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/twisted/conch/scripts/ckeygen.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,19 @@ def _defaultPrivateKeySubtype(keyType):
return "PEM"


def _getKeyOrDefault(options):
def _getKeyOrDefault(options, input_collector=None):
"""
If C{options["filename"]} is None, prompt the user to enter a path
or attempt to set it to .ssh/id_rsa
"""
if input_collector is None:
input_collector = input
filename = options["filename"]
if not filename:
filename = os.path.expanduser("~/.ssh/id_rsa")
if platform.system() == "Windows":
filename = os.path.expandvars(R"%HOMEPATH %\.ssh\id_rsa")
filename = input("Enter file in which the key is (%s): " % filename) or filename
filename = input_collector("Enter file in which the key is (%s): " % filename) or filename
return filename


Expand Down
18 changes: 11 additions & 7 deletions src/twisted/conch/test/test_ckeygen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
"""

import getpass
from io import StringIO
import os
import subprocess
import sys
from io import StringIO

from twisted.conch.test.keydata import (
privateECDSA_openssh,
Expand Down Expand Up @@ -635,21 +636,24 @@ def test_useDefaultForKey(self):
L{options} will default to "~/.ssh/id_rsa" if the user doesn't
specify a key.
"""
input_prompts = []

def mock_input(*args):
return ""

import builtins
return input_prompts.append("")

self.patch(builtins, "input", mock_input)
options = {"filename": ""}
filename = _getKeyOrDefault(options)
filename = _getKeyOrDefault(options, mock_input)
self.assertEqual(
options["filename"],
"",
)
# Resolved path is an RSA key inside .ssh dir.
self.assertTrue(filename.endswith(os.path.join(".ssh", "id_rsa")))
# The user is prompted once to enter the path, since no path was
# provided via CLI.
self.assertEqual(1, len(input_prompts))
self.assertTrue(
"id_rsa" in filename,
"Enter file in which the key is (" in input_prompts
)

def test_displayPublicKeyHandleFileNotFound(self):
Expand Down

0 comments on commit e1745b3

Please sign in to comment.