-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
payload.py
431 lines (392 loc) · 16.6 KB
/
payload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
# -*- coding: utf-8 -*-
'''
Many aspects of the salt payload need to be managed, from the return of
encrypted keys to general payload dynamics and packaging, these happen
in here
'''
# Import python libs
from __future__ import absolute_import, print_function, unicode_literals
# import sys # Use if sys is commented out below
import logging
import gc
import datetime
# Import salt libs
import salt.log
import salt.transport.frame
import salt.utils.immutabletypes as immutabletypes
import salt.utils.stringutils
from salt.exceptions import SaltReqTimeoutError
from salt.utils.data import CaseInsensitiveDict
# Import third party libs
from salt.ext import six
try:
import zmq
except ImportError:
# No need for zeromq in local mode
pass
log = logging.getLogger(__name__)
HAS_MSGPACK = False
try:
# Attempt to import msgpack
import msgpack
# There is a serialization issue on ARM and potentially other platforms
# for some msgpack bindings, check for it
if msgpack.version >= (0, 4, 0):
if msgpack.loads(msgpack.dumps([1, 2, 3], use_bin_type=False), use_list=True) is None:
raise ImportError
else:
if msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None:
raise ImportError
HAS_MSGPACK = True
except ImportError:
# Fall back to msgpack_pure
try:
import msgpack_pure as msgpack # pylint: disable=import-error
HAS_MSGPACK = True
except ImportError:
# TODO: Come up with a sane way to get a configured logfile
# and write to the logfile when this error is hit also
LOG_FORMAT = '[%(levelname)-8s] %(message)s'
salt.log.setup_console_logger(log_format=LOG_FORMAT)
log.fatal('Unable to import msgpack or msgpack_pure python modules')
# Don't exit if msgpack is not available, this is to make local mode
# work without msgpack
#sys.exit(salt.defaults.exitcodes.EX_GENERIC)
if HAS_MSGPACK:
import salt.utils.msgpack
if HAS_MSGPACK and not hasattr(msgpack, 'exceptions'):
    # Compatibility shim: give msgpack modules that lack an ``exceptions``
    # attribute one that provides ``PackValueError``, so code in this module
    # can always refer to ``msgpack.exceptions.PackValueError``.
    class PackValueError(Exception):
        '''
        older versions of msgpack do not have PackValueError
        '''

    class exceptions(object):
        '''
        older versions of msgpack do not have an exceptions module
        '''
        # Expose the exception *class*, not an instance of it: an ``except``
        # clause can only match exception classes, and Python 3 raises
        # TypeError when asked to match against anything else.
        PackValueError = PackValueError

    msgpack.exceptions = exceptions()
def package(payload):
    '''
    Serialize ``payload`` with msgpack and return the packed result.

    This is a thin wrapper around ``salt.utils.msgpack.dumps``; it exists so
    that the serialization can become a custom option in the future without
    touching the callers.
    '''
    packed = salt.utils.msgpack.dumps(payload, _msgpack_module=msgpack)
    return packed
def unpackage(package_):
    '''
    Deserialize a msgpack-packed payload and return the resulting object.

    ``use_list=True`` is passed through to the msgpack loader.
    '''
    load_options = {'use_list': True, '_msgpack_module': msgpack}
    return salt.utils.msgpack.loads(package_, **load_options)
def format_payload(enc, **kwargs):
    '''
    Build and package a payload.

    ``enc`` is stored under the ``enc`` key and every keyword argument
    becomes an entry of the ``load`` dict; the resulting payload is passed
    through ``package`` and the packed result is returned.
    '''
    # dict(kwargs) makes the same shallow copy the explicit key loop would.
    return package({'enc': enc, 'load': dict(kwargs)})
class Serial(object):
    '''
    Create a serialization object, this object manages all message
    serialization in Salt
    '''
    def __init__(self, opts):
        '''
        Record the serializer name.

        :param opts: Either an options dict (its ``serial`` key is used,
                     defaulting to ``'msgpack'``), a string naming the
                     serializer, or anything else, in which case
                     ``'msgpack'`` is used.
        '''
        if isinstance(opts, dict):
            self.serial = opts.get('serial', 'msgpack')
        elif isinstance(opts, six.string_types):
            self.serial = opts
        else:
            self.serial = 'msgpack'

    def loads(self, msg, encoding=None, raw=False):
        '''
        Run the correct loads serialization format

        :param msg: The packed msgpack message to deserialize.
        :param encoding: Useful for Python 3 support. If the msgpack data
                         was encoded using "use_bin_type=True", this will
                         differentiate between the 'bytes' type and the
                         'str' type by decoding contents with 'str' type
                         to what the encoding was set as. Recommended
                         encoding is 'utf-8' when using Python 3.
                         If the msgpack data was not encoded using
                         "use_bin_type=True", it will try to decode
                         all 'bytes' and 'str' data (the distinction has
                         been lost in this case) to what the encoding is
                         set as. In this case, it will fail if any of
                         the contents cannot be converted.
        :param raw: When true, the result is returned without passing it
                    through ``decode_embedded_strs`` (that step only runs
                    on Python 3 when no ``encoding`` was given).
        :return: The deserialized object.
        :raises Exception: Any deserialization error is logged and re-raised.
        '''
        try:
            def ext_type_decoder(code, data):
                # Extended type 78 carries a timestamp string written by
                # ``dumps``; any other extended type is returned untouched.
                if code == 78:
                    data = salt.utils.stringutils.to_unicode(data)
                    return datetime.datetime.strptime(data, '%Y%m%dT%H:%M:%S.%f')
                return data

            gc.disable()  # performance optimization for msgpack
            loads_kwargs = {'use_list': True,
                            'ext_hook': ext_type_decoder,
                            '_msgpack_module': msgpack}
            if msgpack.version >= (0, 4, 0):
                # msgpack only supports 'encoding' starting in 0.4.0.
                # Due to this, if we don't need it, don't pass it at all so
                # that under Python 2 we can still work with older versions
                # of msgpack.
                if msgpack.version >= (0, 5, 2):
                    # From 0.5.2 on the 'raw' flag is passed instead of
                    # 'encoding'.
                    loads_kwargs['raw'] = encoding is None
                else:
                    loads_kwargs['encoding'] = encoding
                try:
                    ret = salt.utils.msgpack.loads(msg, **loads_kwargs)
                except UnicodeDecodeError:
                    # msg contains binary data
                    loads_kwargs.pop('raw', None)
                    loads_kwargs.pop('encoding', None)
                    ret = salt.utils.msgpack.loads(msg, **loads_kwargs)
            else:
                ret = salt.utils.msgpack.loads(msg, **loads_kwargs)
            if six.PY3 and encoding is None and not raw:
                ret = salt.transport.frame.decode_embedded_strs(ret)
        except Exception as exc:
            log.critical(
                'Could not deserialize msgpack message. This often happens '
                'when trying to read a file not in binary mode. '
                'To see message payload, enable debug logging and retry. '
                'Exception: %s', exc
            )
            log.debug('Msgpack deserialization failure on message: %s', msg)
            gc.collect()
            raise
        finally:
            gc.enable()
        return ret

    def load(self, fn_):
        '''
        Run the correct serialization to load a file

        :param fn_: An open file-like object. It is read completely and
                    always closed, even when reading fails.
        :return: The deserialized content, or ``None`` when the file is
                 empty.
        '''
        try:
            data = fn_.read()
        finally:
            # Close the handle even if read() raised so it is not leaked.
            fn_.close()
        if not data:
            return None
        if six.PY3:
            return self.loads(data, encoding='utf-8')
        return self.loads(data)

    def dumps(self, msg, use_bin_type=False):
        '''
        Run the correct dumps serialization format

        :param msg: The object to serialize.
        :param use_bin_type: Useful for Python 3 support. Tells msgpack to
                             differentiate between 'str' and 'bytes' types
                             by encoding them differently.
                             Since this changes the wire protocol, this
                             option should not be used outside of IPC.
        :return: The packed message.
        '''
        def ext_type_encoder(obj):
            # Handed to the packer as its ``default`` hook to convert
            # objects into something msgpack can serialize.
            if isinstance(obj, six.integer_types):
                # msgpack can't handle the very long Python longs for jids
                # Convert any very long longs to strings
                return six.text_type(obj)
            elif isinstance(obj, (datetime.datetime, datetime.date)):
                # msgpack doesn't support datetime.datetime and datetime.date datatypes.
                # So here we have converted these types to custom datatype
                # This is msgpack Extended types numbered 78
                return msgpack.ExtType(78, salt.utils.stringutils.to_bytes(
                    obj.strftime('%Y%m%dT%H:%M:%S.%f')))
            # The same for immutable types
            elif isinstance(obj, immutabletypes.ImmutableDict):
                return dict(obj)
            elif isinstance(obj, immutabletypes.ImmutableList):
                return list(obj)
            elif isinstance(obj, (set, immutabletypes.ImmutableSet)):
                # msgpack can't handle set so translate it to tuple
                return tuple(obj)
            elif isinstance(obj, CaseInsensitiveDict):
                return dict(obj)
            # No known conversion applies. Let msgpack raise its own error.
            return obj

        def pack(payload):
            # Single place for the version-dependent dumps call.
            if msgpack.version >= (0, 4, 0):
                # msgpack only supports 'use_bin_type' starting in 0.4.0.
                # Due to this, if we don't need it, don't pass it at all so
                # that under Python 2 we can still work with older versions
                # of msgpack.
                return salt.utils.msgpack.dumps(payload, default=ext_type_encoder,
                                                use_bin_type=use_bin_type,
                                                _msgpack_module=msgpack)
            return salt.utils.msgpack.dumps(payload, default=ext_type_encoder,
                                            _msgpack_module=msgpack)

        try:
            return pack(msg)
        except (OverflowError, msgpack.exceptions.PackValueError):
            # msgpack<=0.4.6 don't call ext encoder on very long integers raising the error instead.
            # Convert any very long longs to strings and call dumps again.
            def verylong_encoder(obj, context):
                # Make sure we catch recursion here.
                objid = id(obj)
                # This instance list needs to correspond to the types recursed
                # in the below if/elif chain. Also update
                # tests/unit/test_payload.py
                if objid in context and isinstance(obj, (dict, list, tuple)):
                    return '<Recursion on {} with id={}>'.format(type(obj).__name__, id(obj))
                context.add(objid)

                # The isinstance checks in this if/elif chain need to be
                # kept in sync with the above recursion check.
                if isinstance(obj, dict):
                    # Iterate over a copy because the dict itself is
                    # updated in place while walking it.
                    for key, value in six.iteritems(obj.copy()):
                        obj[key] = verylong_encoder(value, context)
                    return dict(obj)
                elif isinstance(obj, (list, tuple)):
                    obj = list(obj)
                    for idx, entry in enumerate(obj):
                        obj[idx] = verylong_encoder(entry, context)
                    return obj
                # A value of an Integer object is limited from -(2^63) upto (2^64)-1 by MessagePack
                # spec. Here we care only of JIDs that are positive integers.
                if isinstance(obj, six.integer_types) and obj >= pow(2, 64):
                    return six.text_type(obj)
                return obj

            return pack(verylong_encoder(msg, set()))

    def dump(self, msg, fn_):
        '''
        Serialize the correct data into the named file object

        :param msg: The object to serialize.
        :param fn_: An open, writable file-like object. It is always closed,
                    even when serializing or writing fails.
        '''
        try:
            if six.PY2:
                fn_.write(self.dumps(msg))
            else:
                # When using Python 3, write files in such a way
                # that the 'bytes' and 'str' types are distinguishable
                # by using "use_bin_type=True".
                fn_.write(self.dumps(msg, use_bin_type=True))
        finally:
            # Close the handle on failure too so it is not leaked.
            fn_.close()
class SREQ(object):
    '''
    Create a generic interface to wrap salt zeromq req calls.
    '''
    def __init__(self, master, id_='', serial='msgpack', linger=0, opts=None):
        '''
        :param master: Address handed to ``connect`` on the request socket.
        :param id_: Optional socket identity; only set when non-empty.
        :param serial: Value passed to ``Serial`` to build the serializer.
        :param linger: Linger value assigned to the request socket.
        :param opts: Optional options dict; the ``tcp_keepalive*`` keys are
                     read when the socket is created.
        '''
        self.master = master
        self.id_ = id_
        self.serial = Serial(serial)
        self.linger = linger
        self.context = zmq.Context()
        self.poller = zmq.Poller()
        self.opts = opts

    @property
    def socket(self):
        '''
        Lazily create the socket.
        '''
        if not hasattr(self, '_socket'):
            # create a new one
            self._socket = self.context.socket(zmq.REQ)
            if hasattr(zmq, 'RECONNECT_IVL_MAX'):
                self._socket.setsockopt(
                    zmq.RECONNECT_IVL_MAX, 5000
                )

            self._set_tcp_keepalive()
            if self.master.startswith('tcp://['):
                # Hint PF type if bracket enclosed IPv6 address
                if hasattr(zmq, 'IPV6'):
                    self._socket.setsockopt(zmq.IPV6, 1)
                elif hasattr(zmq, 'IPV4ONLY'):
                    self._socket.setsockopt(zmq.IPV4ONLY, 0)
            self._socket.linger = self.linger
            if self.id_:
                self._socket.setsockopt(zmq.IDENTITY, self.id_)
            self._socket.connect(self.master)
        return self._socket

    def _set_tcp_keepalive(self):
        '''
        Copy the ``tcp_keepalive*`` settings present in ``self.opts`` onto
        the socket. Does nothing when zmq lacks ``TCP_KEEPALIVE`` or no
        options were supplied.
        '''
        if hasattr(zmq, 'TCP_KEEPALIVE') and self.opts:
            keepalive_options = (
                ('tcp_keepalive', 'TCP_KEEPALIVE'),
                ('tcp_keepalive_idle', 'TCP_KEEPALIVE_IDLE'),
                ('tcp_keepalive_cnt', 'TCP_KEEPALIVE_CNT'),
                ('tcp_keepalive_intvl', 'TCP_KEEPALIVE_INTVL'),
            )
            for opt_name, zmq_name in keepalive_options:
                if opt_name in self.opts:
                    # The zmq constant is only looked up for options that
                    # are actually configured.
                    self._socket.setsockopt(
                        getattr(zmq, zmq_name), self.opts[opt_name]
                    )

    def _registered_sockets(self):
        '''
        Return a snapshot list of the sockets registered with the poller.

        ``poller.sockets`` is handled both as a dict keyed by socket and as
        a sequence whose entries hold the socket at index 0. A copy is
        returned so callers can unregister while looping without mutating
        the container they iterate over. Returns an empty list when the
        poller was never created.
        '''
        poller = getattr(self, 'poller', None)
        if poller is None:
            return []
        registered = poller.sockets
        if isinstance(registered, dict):
            return list(registered.keys())
        return [entry[0] for entry in registered]

    def clear_socket(self):
        '''
        delete socket if you have it
        '''
        if hasattr(self, '_socket'):
            for socket in self._registered_sockets():
                log.trace('Unregistering socket: %s', socket)
                self.poller.unregister(socket)
            del self._socket

    def send(self, enc, load, tries=1, timeout=60):
        '''
        Takes two arguments, the encryption type and the base payload

        :param enc: Stored under the ``enc`` key of the payload.
        :param load: Stored under the ``load`` key of the payload.
        :param tries: Number of poll attempts before giving up.
        :param timeout: Seconds to wait for a reply on each attempt.
        :return: The deserialized reply.
        :raises SaltReqTimeoutError: When no reply arrived within ``tries``
                                     attempts; the socket is cleared first.
        '''
        payload = {'enc': enc, 'load': load}
        pkg = self.serial.dumps(payload)
        # The request is sent once; further tries only poll again.
        self.socket.send(pkg)
        self.poller.register(self.socket, zmq.POLLIN)
        tried = 0
        while True:
            polled = self.poller.poll(timeout * 1000)
            tried += 1
            if polled:
                break
            if tries > 1:
                log.info(
                    'SaltReqTimeoutError: after %s seconds. (Try %s of %s)',
                    timeout, tried, tries
                )
            if tried >= tries:
                self.clear_socket()
                raise SaltReqTimeoutError(
                    'SaltReqTimeoutError: after {0} seconds, ran {1} '
                    'tries'.format(timeout * tried, tried)
                )
        return self.serial.loads(self.socket.recv())

    def send_auto(self, payload, tries=1, timeout=60):
        '''
        Detect the encryption type based on the payload

        ``enc`` defaults to ``'clear'`` and ``load`` to an empty dict when
        the payload does not carry them.
        '''
        enc = payload.get('enc', 'clear')
        load = payload.get('load', {})
        return self.send(enc, load, tries, timeout)

    def destroy(self):
        '''
        Close and unregister every socket known to the poller, close the
        request socket if one was created, and terminate the context.

        Safe to call on a partially initialised instance, which matters
        because ``__del__`` calls it even when ``__init__`` failed.
        '''
        for socket in self._registered_sockets():
            if socket.closed is False:
                socket.setsockopt(zmq.LINGER, 1)
                socket.close()
            self.poller.unregister(socket)
        # Read the attribute directly: going through the ``socket`` property
        # would create and connect a new socket only to close it again.
        own_socket = getattr(self, '_socket', None)
        if own_socket is not None and own_socket.closed is False:
            own_socket.setsockopt(zmq.LINGER, 1)
            own_socket.close()
        context = getattr(self, 'context', None)
        if context is not None and context.closed is False:
            context.term()

    def __del__(self):
        self.destroy()