Skip to content

Commit

Permalink
Add SSH tunnel support (#1301)
Browse files Browse the repository at this point in the history
* Add initial sshtunnel support

* Force CI to rerun.

* Fix unit test for 3.6.

* Black.

Co-authored-by: Irina Truong <i.chernyavska@gmail.com>
  • Loading branch information
sweenu and j-bennet authored Feb 18, 2022
1 parent 78843ac commit ed9d123
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- name: Install requirements
run: |
pip install -U pip setuptools
pip install --no-cache-dir .
pip install --no-cache-dir ".[sshtunnel]"
pip install -r requirements-dev.txt
pip install keyrings.alt>=3.1
Expand Down
3 changes: 2 additions & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ Contributors:
* Kevin Marsh (kevinmarsh)
* Eero Ruohola (ruohola)
* Miroslav Šedivý (eumiro)
* Eric R Young (ERYoung11)
* Eric R Young (ERYoung11)
* Paweł Sacawa (psacawa)
* Bruno Inec (sweenu)

Creator:
--------
Expand Down
1 change: 1 addition & 0 deletions changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Features:

* Add `max_field_width` setting to config, to enable more control over field truncation ([related issue](https://github.com/dbcli/pgcli/issues/1250)).
* Re-run last query via bare `\watch`. (Thanks: `Saif Hakim`_)
* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)).

Bug fixes:
----------
Expand Down
82 changes: 78 additions & 4 deletions pgcli/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import platform
import warnings
from os.path import expanduser

from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
from .config import skip_initial_comment

warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")

import atexit
import os
import re
import sys
Expand All @@ -21,6 +21,8 @@
import itertools
import platform
from time import time, sleep
from typing import Optional
from urllib.parse import urlparse

keyring = None # keyring will be loaded later

Expand Down Expand Up @@ -78,12 +80,21 @@

from getpass import getuser
from psycopg2 import OperationalError, InterfaceError
from psycopg2.extensions import make_dsn, parse_dsn
import psycopg2

from collections import namedtuple

from textwrap import dedent

try:
import sshtunnel

SSH_TUNNEL_SUPPORT = True
except ImportError:
SSH_TUNNEL_SUPPORT = False


# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
DEFAULT_MAX_FIELD_WIDTH = 500
Expand Down Expand Up @@ -168,8 +179,8 @@ def __init__(
prompt_dsn=None,
auto_vertical_output=False,
warn=None,
ssh_tunnel_url: Optional[str] = None,
):

self.force_passwd_prompt = force_passwd_prompt
self.never_passwd_prompt = never_passwd_prompt
self.pgexecute = pgexecute
Expand Down Expand Up @@ -275,6 +286,9 @@ def __init__(

self.prompt_app = None

self.ssh_tunnel_url = ssh_tunnel_url
self.ssh_tunnel = None

def quit(self):
raise PgCliQuitError

Expand Down Expand Up @@ -585,6 +599,50 @@ def should_ask_for_password(exc):
return True
return False

if self.ssh_tunnel_url:
# We add the protocol as urlparse doesn't find it by itself
if "://" not in self.ssh_tunnel_url:
self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"

if dsn:
parsed_dsn = parse_dsn(dsn)
if "host" in parsed_dsn:
host = parsed_dsn["host"]
if "port" in parsed_dsn:
port = parsed_dsn["port"]

tunnel_info = urlparse(self.ssh_tunnel_url)
params = {
"local_bind_address": ("127.0.0.1",),
"remote_bind_address": (host, int(port or 5432)),
"ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
"logger": self.logger,
}
if tunnel_info.username:
params["ssh_username"] = tunnel_info.username
if tunnel_info.password:
params["ssh_password"] = tunnel_info.password

# Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
# We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
logger_handlers = self.logger.handlers.copy()
try:
self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
self.ssh_tunnel.start()
except Exception as e:
self.logger.handlers = logger_handlers
self.logger.error("traceback: %r", traceback.format_exc())
click.secho(str(e), err=True, fg="red")
exit(1)
self.logger.handlers = logger_handlers

atexit.register(self.ssh_tunnel.stop)
host = "127.0.0.1"
port = self.ssh_tunnel.local_bind_ports[0]

if dsn:
dsn = make_dsn(dsn, host=host, port=port)

# Attempt to connect to the database.
# Note that passwd may be empty on the first attempt. If connection
# fails because of a missing or incorrect password, but we're allowed to
Expand Down Expand Up @@ -1222,7 +1280,7 @@ def echo_via_pager(self, text, color=None):
"--list",
"list_databases",
is_flag=True,
help="list " "available databases, then exit.",
help="list available databases, then exit.",
)
@click.option(
"--auto-vertical-output",
Expand All @@ -1235,6 +1293,11 @@ def echo_via_pager(self, text, color=None):
type=click.Choice(["all", "moderate", "off"]),
help="Warn before running a destructive query.",
)
@click.option(
"--ssh-tunnel",
default=None,
help="Open an SSH tunnel to the given address and connect to the database from it.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
Expand All @@ -1258,6 +1321,7 @@ def cli(
auto_vertical_output,
list_dsn,
warn,
ssh_tunnel: str,
):
if version:
print("Version:", __version__)
Expand Down Expand Up @@ -1294,6 +1358,15 @@ def cli(
)
exit(1)

if ssh_tunnel and not SSH_TUNNEL_SUPPORT:
click.secho(
'Cannot open SSH tunnel, "sshtunnel" package was not found. '
"Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.",
err=True,
fg="red",
)
exit(1)

pgcli = PGCli(
prompt_passwd,
never_prompt,
Expand All @@ -1305,6 +1378,7 @@ def cli(
prompt_dsn=prompt_dsn,
auto_vertical_output=auto_vertical_output,
warn=warn,
ssh_tunnel_url=ssh_tunnel,
)

# Choose which ever one has a valid value.
Expand Down Expand Up @@ -1548,7 +1622,7 @@ def parse_service_info(service):
elif os.getenv("PGSYSCONFDIR"):
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
else:
service_file = expanduser("~/.pg_service.conf")
service_file = os.path.expanduser("~/.pg_service.conf")
if not service or not os.path.exists(service_file):
# nothing to do
return None, service_file
Expand Down
18 changes: 7 additions & 11 deletions pgcli/packages/parseutils/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,13 @@ def extract_from_part(parsed, stop_at_punctuation=True):
yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper()
if (
item_val
in (
"COPY",
"FROM",
"INTO",
"UPDATE",
"TABLE",
)
or item_val.endswith("JOIN")
):
if item_val in (
"COPY",
"FROM",
"INTO",
"UPDATE",
"TABLE",
) or item_val.endswith("JOIN"):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
Expand Down
13 changes: 8 additions & 5 deletions pgcli/pgcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,11 +491,14 @@ def get_completions(self, document, complete_event, smart_completion=None):

def get_column_matches(self, suggestion, word_before_cursor):
tables = suggestion.table_refs
do_qualify = suggestion.qualifiable and {
"always": True,
"never": False,
"if_more_than_one_table": len(tables) > 1,
}[self.qualify_columns]
do_qualify = (
suggestion.qualifiable
and {
"always": True,
"never": False,
"if_more_than_one_table": len(tables) > 1,
}[self.qualify_columns]
)
qualify = lambda col, tbl: (
(tbl + "." + self.case(col)) if do_qualify else self.case(col)
)
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@
description=description,
long_description=open("README.rst").read(),
install_requires=install_requirements,
extras_require={"keyring": ["keyring >= 12.2.0"]},
extras_require={
"keyring": ["keyring >= 12.2.0"],
"sshtunnel": ["sshtunnel >= 0.4.0"],
},
python_requires=">=3.6",
entry_points="""
[console_scripts]
Expand Down
4 changes: 2 additions & 2 deletions tests/features/steps/basic_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def step_see_error_message(context):
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
context.tmpfile_sql_help.write(br"\?")
context.tmpfile_sql_help.write(rb"\?")
context.tmpfile_sql_help.flush()
context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)


Expand Down
Loading

0 comments on commit ed9d123

Please sign in to comment.