strongswan/src/libcharon/plugins/vici/python/vici/protocol.py

207 lines
6.9 KiB
Python

import io
import socket
import struct
from collections import namedtuple
from collections import OrderedDict
from .compat import iteritems
from .exception import DeserializationException
class Transport(object):
    """Length-prefixed packet transport on top of a socket object.

    Every packet on the wire is preceded by a 4-byte unsigned length in
    network byte order (struct format "!I").
    """

    HEADER_LENGTH = 4  # size in bytes of the length prefix
    MAX_SEGMENT = 512 * 1024  # 512 KiB; not referenced within this class

    def __init__(self, sock):
        """Wrap sock, the socket object used for all I/O."""
        self.socket = sock

    def send(self, packet):
        """Send packet (bytes), preceded by its length header."""
        self.socket.sendall(struct.pack("!I", len(packet)) + packet)

    def receive(self):
        """Read one length-prefixed packet and return its payload bytes.

        Raises socket.error if the peer closes the connection before the
        header or the announced payload was received completely.
        """
        raw_length = self._recvall(self.HEADER_LENGTH)
        length, = struct.unpack("!I", raw_length)
        return self._recvall(length)

    def close(self):
        """Shut down and close the socket.

        The socket is closed even if shutdown() raises (for instance when
        the peer is already gone), so the descriptor is not leaked. The
        error raised by shutdown() still propagates to the caller.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        finally:
            self.socket.close()

    def _recvall(self, count):
        """Ensure to read count bytes from the socket"""
        # Collect the chunks and join them once instead of growing an
        # immutable bytes object on every recv() call.
        chunks = []
        remaining = count
        while remaining > 0:
            buf = self.socket.recv(remaining)
            if not buf:
                # recv() returning no data means the peer closed the stream
                raise socket.error('Connection closed')
            chunks.append(buf)
            remaining -= len(buf)
        return b"".join(chunks)
class Packet(object):
    """Builds outgoing packets and splits incoming ones.

    A named packet consists of the packet type byte, a one-byte name
    length and the UTF-8 encoded name, optionally followed by an already
    serialized message. Unnamed packets only carry the type byte followed
    by the message.
    """

    CMD_REQUEST = 0       # Named request message
    CMD_RESPONSE = 1      # Unnamed response message for a request
    CMD_UNKNOWN = 2       # Unnamed response if requested command is unknown
    EVENT_REGISTER = 3    # Named event registration request
    EVENT_UNREGISTER = 4  # Named event de-registration request
    EVENT_CONFIRM = 5     # Unnamed confirmation for event (de-)registration
    EVENT_UNKNOWN = 6     # Unnamed response if event (de-)registration failed
    EVENT = 7             # Named event message

    # Result of parse() for every packet type except EVENT
    ParsedPacket = namedtuple("ParsedPacket", ["response_type", "payload"])
    # Result of parse() for EVENT packets, which also carry the event name
    ParsedEventPacket = namedtuple(
        "ParsedEventPacket", ["response_type", "event_type", "payload"]
    )

    @classmethod
    def _named_request(cls, request_type, request, message=None):
        """Return the bytes of a named packet.

        request is the name as text; message, if given, is appended
        unchanged after the encoded name.
        """
        encoded_name = request.encode("UTF-8")
        header = struct.pack("!BB", request_type, len(encoded_name))
        packet = header + encoded_name
        if message is None:
            return packet
        return packet + message

    @classmethod
    def request(cls, command, message=None):
        """Return a CMD_REQUEST packet for command with optional message."""
        return cls._named_request(cls.CMD_REQUEST, command, message)

    @classmethod
    def register_event(cls, event_type):
        """Return an EVENT_REGISTER packet for event_type."""
        return cls._named_request(cls.EVENT_REGISTER, event_type)

    @classmethod
    def unregister_event(cls, event_type):
        """Return an EVENT_UNREGISTER packet for event_type."""
        return cls._named_request(cls.EVENT_UNREGISTER, event_type)

    @classmethod
    def parse(cls, packet):
        """Split the raw bytes of a received packet.

        Returns a ParsedEventPacket for EVENT packets (event_type is the
        raw, undecoded name) and a ParsedPacket for everything else. In
        both cases payload is a FiniteStream over packet, positioned right
        after the part that was consumed here.
        """
        stream = FiniteStream(packet)
        response_type = struct.unpack("!B", stream.read(1))[0]
        if response_type != cls.EVENT:
            return cls.ParsedPacket(response_type, stream)

        # Only event packets carry a length-prefixed name in front of the
        # message.
        name_length = struct.unpack("!B", stream.read(1))[0]
        event_name = stream.read(name_length)
        return cls.ParsedEventPacket(response_type, event_name, stream)
class Message(object):
    """Serializer and deserializer for the body of a message.

    A message is a flat sequence of typed elements. Each element starts
    with a one-byte type marker. Names are prefixed by a one-byte length,
    values by a two-byte length in network byte order.
    """

    SECTION_START = 1  # Begin a new section having a name
    SECTION_END = 2    # End a previously started section
    KEY_VALUE = 3      # Define a value for a named key in the section
    LIST_START = 4     # Begin a named list for list items
    LIST_ITEM = 5      # Define an unnamed item value in the current list
    LIST_END = 6       # End a previously started list

    @classmethod
    def serialize(cls, message):
        """Encode the dict message and return the resulting bytes.

        dict values become (nested) sections, list values become lists and
        every other value becomes a key/value element. Values and list
        items that are not bytes are converted with str() and encoded as
        UTF-8.
        """
        def encode_named_type(marker, name):
            name = name.encode("UTF-8")
            return struct.pack("!BB", marker, len(name)) + name

        def encode_blob(value):
            if not isinstance(value, bytes):
                value = str(value).encode("UTF-8")
            return struct.pack("!H", len(value)) + value

        def serialize_list(lst):
            item_marker = struct.pack("!B", cls.LIST_ITEM)
            return b"".join(item_marker + encode_blob(item) for item in lst)

        def serialize_dict(d):
            # Gather the encoded pieces and join once rather than
            # repeatedly reallocating a growing bytes object.
            parts = []
            for key, value in iteritems(d):
                if isinstance(value, dict):
                    parts.append(encode_named_type(cls.SECTION_START, key))
                    parts.append(serialize_dict(value))
                    parts.append(struct.pack("!B", cls.SECTION_END))
                elif isinstance(value, list):
                    parts.append(encode_named_type(cls.LIST_START, key))
                    parts.append(serialize_list(value))
                    parts.append(struct.pack("!B", cls.LIST_END))
                else:
                    parts.append(encode_named_type(cls.KEY_VALUE, key))
                    parts.append(encode_blob(value))
            return b"".join(parts)

        return serialize_dict(message)

    @classmethod
    def deserialize(cls, stream):
        """Decode the message in stream and return it as an OrderedDict.

        stream has to provide read(), tell() and has_more(). Sections are
        returned as nested OrderedDicts, lists as Python lists; names are
        decoded as UTF-8 while values and list items stay bytes.

        Raises DeserializationException for an unknown element type, a
        name or value shorter than its announced length, a list that is
        not terminated by LIST_END and for unbalanced sections. A truncated
        type marker or length field raises struct.error.
        """
        def read_exact(stream, count):
            # A short read means the data ends inside a name or value;
            # returning the partial data would silently corrupt the result.
            data = stream.read(count)
            if len(data) != count:
                raise DeserializationException(
                    "Unexpected end of data at {pos}".format(
                        pos=stream.tell()
                    )
                )
            return data

        def decode_named_type(stream):
            length, = struct.unpack("!B", stream.read(1))
            return read_exact(stream, length).decode("UTF-8")

        def decode_blob(stream):
            length, = struct.unpack("!H", stream.read(2))
            return read_exact(stream, length)

        def decode_list_item(stream):
            marker, = struct.unpack("!B", stream.read(1))
            while marker == cls.LIST_ITEM:
                yield decode_blob(stream)
                marker, = struct.unpack("!B", stream.read(1))

            if marker != cls.LIST_END:
                raise DeserializationException(
                    "Expected end of list at {pos}".format(pos=stream.tell())
                )

        section = OrderedDict()
        # Enclosing sections of the one currently being filled
        section_stack = []
        while stream.has_more():
            element_type, = struct.unpack("!B", stream.read(1))
            if element_type == cls.SECTION_START:
                section_name = decode_named_type(stream)
                new_section = OrderedDict()
                section[section_name] = new_section
                section_stack.append(section)
                section = new_section
            elif element_type == cls.LIST_START:
                list_name = decode_named_type(stream)
                section[list_name] = list(decode_list_item(stream))
            elif element_type == cls.KEY_VALUE:
                key = decode_named_type(stream)
                section[key] = decode_blob(stream)
            elif element_type == cls.SECTION_END:
                if not section_stack:
                    raise DeserializationException(
                        "Unexpected end of section at {pos}".format(
                            pos=stream.tell()
                        )
                    )
                section = section_stack.pop()
            else:
                # Skipping an unknown marker would make the following bytes
                # be misread as elements, so fail instead.
                raise DeserializationException(
                    "Unknown element type {marker} at {pos}".format(
                        marker=element_type, pos=stream.tell()
                    )
                )

        if section_stack:
            raise DeserializationException("Expected end of section")
        return section
class FiniteStream(io.BytesIO):
    """In-memory byte stream that knows its size and whether data is left."""

    def __len__(self):
        """Return the total number of bytes held by the stream."""
        contents = self.getvalue()
        return len(contents)

    def has_more(self):
        """Return True while the read position is before the end of data."""
        unread = len(self) - self.tell()
        return unread > 0