Skip to content

Commit

Permalink
Add bind_address option (#529)
Browse files Browse the repository at this point in the history
Allow connecting to the DB from a specific network interface
  • Loading branch information
dciabrin authored and methane committed Nov 16, 2016
1 parent 6c9d31a commit 755dfdc
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ def __init__(self, host=None, user=None, password="",
compress=None, named_pipe=None, no_delay=None,
autocommit=False, db=None, passwd=None, local_infile=False,
max_allowed_packet=16*1024*1024, defer_connect=False,
auth_plugin_map={}, read_timeout=None, write_timeout=None):
auth_plugin_map={}, read_timeout=None, write_timeout=None,
bind_address=None):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand All @@ -544,6 +545,9 @@ def __init__(self, host=None, user=None, password="",
password: Password to use.
database: Database to use, None to not use a particular one.
port: MySQL port to use, default is usually OK. (default: 3306)
bind_address: When the client has multiple network interfaces, specify
the interface from which to connect to the host. Argument can be
a hostname or an IP address.
unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
charset: Charset you want to use.
sql_mode: Default SQL_MODE to use.
Expand Down Expand Up @@ -632,6 +636,7 @@ def _config(key, arg):
database = _config("database", database)
unix_socket = _config("socket", unix_socket)
port = int(_config("port", port))
bind_address = _config("bind-address", bind_address)
charset = _config("default-character-set", charset)

self.host = host or "localhost"
Expand All @@ -640,6 +645,7 @@ def _config(key, arg):
self.password = password or ""
self.db = database
self.unix_socket = unix_socket
self.bind_address = bind_address
if read_timeout is not None and read_timeout <= 0:
raise ValueError("read_timeout should be >= 0")
self._read_timeout = read_timeout
Expand Down Expand Up @@ -884,10 +890,14 @@ def connect(self, sock=None):
self.host_info = "Localhost via UNIX socket"
if DEBUG: print('connected using unix_socket')
else:
kwargs = {}
if self.bind_address is not None:
kwargs['source_address'] = (self.bind_address, 0)
while True:
try:
sock = socket.create_connection(
(self.host, self.port), self.connect_timeout)
(self.host, self.port), self.connect_timeout,
**kwargs)
break
except (OSError, IOError) as e:
if e.errno == errno.EINTR:
Expand Down

0 comments on commit 755dfdc

Please sign in to comment.