Skip to content

Commit

Permalink
refactor connection's sequence id
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Jan 3, 2016
1 parent 0cffd55 commit 344b933
Showing 1 changed file with 26 additions and 36 deletions.
62 changes: 26 additions & 36 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,6 @@ def __del__(self):
def autocommit(self, value):
self.autocommit_mode = bool(value)
current = self.get_autocommit()
self.next_packet = 1
if value != current:
self._send_autocommit_mode()

Expand Down Expand Up @@ -816,15 +815,13 @@ def query(self, sql, unbuffered=False):
"You may not close previous cursor.")
# if DEBUG:
# print("DEBUG: sending query:", sql)
self.next_packet = 1
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
if PY2:
sql = sql.encode(self.encoding)
else:
sql = sql.encode(self.encoding, 'surrogateescape')
self._execute_command(COMMAND.COM_QUERY, sql)
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
self.next_packet = 1
return self._affected_rows

def next_result(self, unbuffered=False):
Expand Down Expand Up @@ -892,7 +889,7 @@ def connect(self, sock=None):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket = sock
self._rfile = _makefile(sock, 'rb')
self.next_packet = 0
self._next_seq_id = 0

self._get_server_information()
self._request_authentication()
Expand Down Expand Up @@ -933,16 +930,16 @@ def connect(self, sock=None):
# So just reraise it.
raise

def write_packet(self, data):
def write_packet(self, payload):
"""Writes an entire "mysql packet" in its entirety to the network
addings its length and sequence number. Intended for use by plugins
only.
addings its length and sequence number.
"""
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
# Internal note: when you build packet manualy and calls _write_bytes()
# directly, you should set self._next_seq_id properly.
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
if DEBUG: dump_packet(data)

self._write_bytes(data)
self.next_packet = (self.next_packet + 1) % 256
self._next_seq_id = (self._next_seq_id + 1) % 256

def _read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
Expand All @@ -952,8 +949,14 @@ def _read_packet(self, packet_type=MysqlPacket):
while True:
packet_header = self._read_bytes(4)
if DEBUG: dump_packet(packet_header)

btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
bytes_to_read = btrl + (btrh << 16)
if packet_number != self._next_seq_id:
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
self._next_seq_id = (self._next_seq_id + 1) % 256

recv_data = self._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data)
buff += recv_data
Expand All @@ -962,13 +965,7 @@ def _read_packet(self, packet_type=MysqlPacket):
continue
if bytes_to_read < MAX_PACKET_LEN:
break
if packet_number != self.next_packet:
pass
#TODO: check sequence id
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
# (packet_number, self.next_packet))

self.next_packet = (packet_number + 1) % 256
packet = packet_type(buff, self.encoding)
packet.check_error()
return packet
Expand Down Expand Up @@ -1027,33 +1024,32 @@ def _execute_command(self, command, sql):
if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
self._result = None

if isinstance(sql, text_type):
sql = sql.encode(self.encoding)

chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
# +1 is for command
chunk_size = min(self.max_allowed_packet, len(sql) + 1)

# tiny optimization: build first packet manually instead of
# calling self..write_packet()
prelude = struct.pack('<iB', chunk_size, command)
self._write_bytes(prelude + sql[:chunk_size-1])
if DEBUG: dump_packet(prelude + sql)
packet = prelude + sql[:chunk_size-1]
self._write_bytes(packet)
if DEBUG: dump_packet(packet)
self._next_seq_id = 1

self.next_packet = 1
if chunk_size < self.max_allowed_packet:
return

seq_id = 1
sql = sql[chunk_size-1:]
while True:
chunk_size = min(self.max_allowed_packet, len(sql))
prelude = struct.pack('<i', chunk_size)[:3]
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
self._write_bytes(data)
if DEBUG: dump_packet(data)
self.write_packet(sql[:chunk_size])
sql = sql[chunk_size:]
if not sql and chunk_size < self.max_allowed_packet:
break
seq_id += 1
self.next_packet = seq_id%256

def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
Expand Down Expand Up @@ -1448,9 +1444,8 @@ def send_data(self):
"""Send data packets from the local file to the server"""
if not self.connection.socket:
raise err.InterfaceError("(0, '')")
conn = self.connection

# sequence id is 2 as we already sent a query packet
seq_id = 2
try:
with open(self.filename, 'rb') as open_file:
chunk_size = self.connection.max_allowed_packet
Expand All @@ -1460,14 +1455,9 @@ def send_data(self):
chunk = open_file.read(chunk_size)
if not chunk:
break
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
format_str = '!{0}s'.format(len(chunk))
packet += struct.pack(format_str, chunk)
self.connection._write_bytes(packet)
seq_id = (seq_id + 1) % 256
conn.write_packet(chunk)
except IOError:
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
finally:
# send the empty packet to signify we are done sending data
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
conn.write_packet(b'')

0 comments on commit 344b933

Please sign in to comment.