# -*- coding: utf-8 -*-

'''
Very simple (performance oriented) declarative message codec.
Inspired by Pycrate and Scapy.
'''

# TRX Toolkit
#
# (C) 2021 by sysmocom - s.f.m.c. GmbH
# Author: Vadim Yanitskiy
#
# All Rights Reserved
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

from typing import Optional, Callable, Tuple, Any
import abc


class ProtocolError(Exception):
    ''' Error in a protocol definition. '''


class DecodeError(Exception):
    ''' Error during decoding of a field/message. '''


class EncodeError(Exception):
    ''' Error during encoding of a field/message. '''


class Codec(abc.ABC):
    ''' Base class providing encoding and decoding API. '''

    @abc.abstractmethod
    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Decode value(s) from the given buffer of bytes. '''

    @abc.abstractmethod
    def to_bytes(self, vals: dict) -> bytes:
        ''' Encode value(s) into bytes. '''


class Field(Codec):
    ''' Base class representing one field in a Message.

    Derived types implement _from_bytes() and _to_bytes(); this class
    takes care of the presence check and of the length handling.  The
    behaviour of an instance can be tuned by re-assigning its get_pres,
    get_len and get_val callables after construction. '''

    # Default length (0 means the whole buffer)
    DEF_LEN = 0 # type: int

    # Default parameters
    DEF_PARAMS = { } # type: dict

    # Presence of a field during decoding and encoding
    ## get_pres: Callable[[dict], bool]
    # Length of a field for self.from_bytes()
    ## get_len: Callable[[dict, bytes], int]
    # Value of a field for self.to_bytes()
    ## get_val: Callable[[dict], Any]

    def __init__(self, name: str, **kw) -> None:
        ''' Create a field with the given name.

        Recognized keyword arguments: 'len' (length in octets, defaults
        to DEF_LEN) and every key listed in DEF_PARAMS of the actual
        field type.  Other keyword arguments are silently ignored. '''
        self.name = name

        self.len = kw.get('len', self.DEF_LEN)
        if self.len == 0: # flexible field
            self.get_len = lambda _, data: len(data)
        else: # fixed length
            # self.len is looked up on every call, so later changes apply
            self.get_len = lambda vals, _: self.len

        # Field is unconditionally present by default
        self.get_pres = lambda vals: True
        # Field takes its value from the given dict by default
        self.get_val = lambda vals: vals[self.name]

        # Additional parameters for derived field types
        self.p = { key : kw.get(key, self.DEF_PARAMS[key])
                   for key in self.DEF_PARAMS }

    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Decode this field from the beginning of the given buffer.

        Return the number of octets consumed (0 if the field is absent).
        Raise DecodeError if the buffer is shorter than the field. '''
        # Only an explicit False marks the field as absent
        if self.get_pres(vals) is False:
            return 0
        length = self.get_len(vals, data)
        if len(data) < length:
            raise DecodeError('Short read')
        self._from_bytes(vals, data[:length])
        return length

    def to_bytes(self, vals: dict) -> bytes:
        ''' Encode this field into bytes.

        Return an empty buffer if the field is absent.  Raise EncodeError
        if a fixed length field yields an unexpected number of octets. '''
        # Only an explicit False marks the field as absent
        if self.get_pres(vals) is False:
            return b''
        data = self._to_bytes(vals)
        # Flexible fields (self.len == 0) are not length-checked
        if self.len > 0 and len(data) != self.len:
            raise EncodeError('Field length mismatch')
        return data

    @abc.abstractmethod
    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Decode value(s) from the given buffer of bytes. '''
        raise NotImplementedError

    @abc.abstractmethod
    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode value(s) into bytes. '''
        raise NotImplementedError


class Buf(Field):
    ''' A sequence of octets. '''

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Store the given octets as is. '''
        vals[self.name] = data

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Return the value as is (no padding is applied). '''
        # TODO: handle len(self.get_val()) < self.get_len()
        return self.get_val(vals)


class Spare(Field):
    ''' Spare filling for RFU fields or padding. '''

    # Default parameters
    DEF_PARAMS = {
        'filler' : b'\x00',
    }

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Nothing is stored in vals for a spare field. '''
        pass # Just ignore it

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Return the filler repeated get_len() times. '''
        # An empty buffer is passed because nothing is being decoded here
        return self.p['filler'] * self.get_len(vals, b'')


class Uint(Field):
    ''' An integer field: unsigned, N bits, big endian.

    The value seen by the user is (raw * mult + offset), where 'mult'
    and 'offset' are optional keyword parameters. '''

    # Uint8 by default
    DEF_LEN = 1

    # Default parameters
    DEF_PARAMS = {
        'offset' : 0,
        'mult' : 1,
    }

    # Big endian, unsigned
    SIGN = False
    BO = 'big'

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Decode the raw integer, then apply 'mult' and 'offset'. '''
        val = int.from_bytes(data, self.BO, signed=self.SIGN)
        vals[self.name] = val * self.p['mult'] + self.p['offset']

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Reverse 'offset' and 'mult', then encode the raw integer. '''
        # Floor division: values not aligned to 'mult' get rounded down
        val = (self.get_val(vals) - self.p['offset']) // self.p['mult']
        return val.to_bytes(self.len, self.BO, signed=self.SIGN)


class Uint16BE(Uint):
    ''' Unsigned 16-bit integer, big endian. '''
    DEF_LEN = 16 // 8


class Uint16LE(Uint16BE):
    ''' Unsigned 16-bit integer, little endian. '''
    BO = 'little'


class Uint32BE(Uint):
    ''' Unsigned 32-bit integer, big endian. '''
    DEF_LEN = 32 // 8


class Uint32LE(Uint32BE):
    ''' Unsigned 32-bit integer, little endian. '''
    BO = 'little'


class Int(Uint):
    ''' An integer field: signed, N bits, big endian. '''
    SIGN = True


class Int16BE(Int):
    ''' Signed 16-bit integer, big endian. '''
    DEF_LEN = 16 // 8


class Int16LE(Int16BE):
    ''' Signed 16-bit integer, little endian. '''
    BO = 'little'


class Int32BE(Int):
    ''' Signed 32-bit integer, big endian. '''
    DEF_LEN = 32 // 8


class Int32LE(Int32BE):
    ''' Signed 32-bit integer, little endian. '''
    BO = 'little'


class BitFieldSet(Field):
    ''' A set of bit-fields. '''

    # Default parameters
    DEF_PARAMS = {
        # Default field order (MSB first)
        'order' : 'big',
    }

    # To be defined by derived types
    STRUCT = () # type: Tuple['BitField', ...]

    def __init__(self, **kw) -> None:
        ''' Create a set of bit-fields named after the actual class.

        Recognized keyword arguments: 'set' (a tuple of bit-fields,
        defaults to STRUCT), 'order' and 'len' (length in octets; derived
        from the bit-fields if omitted).  Raise ProtocolError if the
        bit-fields are not given as a tuple or do not fit into 'len'. '''
        Field.__init__(self, self.__class__.__name__, **kw)

        self._fields = kw.get('set', self.STRUCT)
        if not isinstance(self._fields, tuple):
            raise ProtocolError('Expected a tuple')

        # LSB first is basically reversed order
        if self.p['order'] in ('little', 'lsb'):
            self._fields = self._fields[::-1]

        # Calculate the overall field length
        if self.len == 0:
            bl_sum = sum(f.bl for f in self._fields)
            self.len = (bl_sum + 7) // 8 # round up to whole octets

        # Re-define self.get_len() since we always know the length
        self.get_len = lambda vals, data: self.len

        # Pre-calculate offset and mask for each field
        # NOTE: these are stored in the bit-field objects themselves
        offset = self.len * 8
        for f in self._fields:
            if f.bl > offset:
                raise ProtocolError(f, 'BitFieldSet overflow')
            f.offset = offset - f.bl
            f.mask = 2 ** f.bl - 1
            offset -= f.bl

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Decode all bit-fields from the given buffer of bytes. '''
        blob = int.from_bytes(data, byteorder='big') # intentionally using 'big' here
        for f in self._fields:
            f.dec_val(vals, blob)

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode all bit-fields into self.len octets. '''
        blob = 0x00
        for f in self._fields: # TODO: use functools.reduce()?
            blob |= f.enc_val(vals)
        return blob.to_bytes(self.len, byteorder='big')


class BitField:
    ''' One field in a BitFieldSet. '''

    # Special fields for BitFieldSet
    offset = 0 # type: int
    mask = 0 # type: int

    class Spare:
        ''' Spare filling in a BitFieldSet. '''

        def __init__(self, bl: int) -> None:
            ''' Create a nameless spare field of the given bit length. '''
            self.name = None
            self.bl = bl

        def enc_val(self, vals: dict) -> int:
            ''' Spare bits are always encoded as zeros. '''
            return 0

        def dec_val(self, vals: dict, blob: int) -> None:
            ''' Nothing is stored in vals for spare bits. '''
            pass # Just ignore it

    def __init__(self, name: str, bl: int, **kw) -> None:
        ''' Create a bit-field with the given name and bit length.

        Optional keyword argument 'val' defines a fixed value.
        Raise ProtocolError if the bit length is less than 1. '''
        if bl < 1: # Ensure proper length
            raise ProtocolError('Incorrect bit-field length')

        self.name = name
        self.bl = bl

        # (Optional) fixed value for encoding and decoding
        self.val = kw.get('val', None) # type: Optional[int]

    def enc_val(self, vals: dict) -> int:
        ''' Return the value masked and shifted to its position.

        The fixed value (if any) takes precedence over vals. '''
        if self.val is None:
            val = vals[self.name]
        else:
            val = self.val
        return (val & self.mask) << self.offset

    def dec_val(self, vals: dict, blob: int) -> None:
        ''' Extract the value from the given integer and store it in vals.

        Raise DecodeError if it differs from the fixed value (if any). '''
        vals[self.name] = (blob >> self.offset) & self.mask
        if (self.val is not None) and (vals[self.name] != self.val):
            raise DecodeError('Unexpected value %d, expected %d'
                              % (vals[self.name], self.val))


class Envelope:
    ''' A group of related fields. '''

    STRUCT = () # type: Tuple[Codec, ...]

    def __init__(self, check_len: bool = True):
        ''' Create an envelope with empty content.

        If check_len is True, decoding raises DecodeError when some
        octets remain in the buffer after the last field. '''
        # TODO: ensure unique field names in self.STRUCT
        self.c = { } # type: dict
        self.check_len = check_len

    def __getitem__(self, key: str) -> Any:
        ''' Return the value stored under the given key. '''
        return self.c[key]

    def __setitem__(self, key: str, val: Any) -> None:
        ''' Store the value under the given key. '''
        self.c[key] = val

    def __delitem__(self, key: str) -> None:
        ''' Remove the value stored under the given key. '''
        del self.c[key]

    def check(self, vals: dict) -> None:
        ''' Check the content before encoding and after decoding.
        Raise exceptions (e.g. ValueError) if something is wrong.

        Do not assert for every possible error (e.g. a negative value
        for a Uint field) if an exception will be thrown by the field's
        to_bytes() method anyway.  Only additional constraints here. '''

    def from_bytes(self, data: bytes) -> int:
        ''' Decode the content from the given buffer of bytes.

        Return the number of octets consumed. '''
        self.c.clear() # forget the old content
        return self._from_bytes(self.c, data)

    def to_bytes(self) -> bytes:
        ''' Encode the content into bytes. '''
        return self._to_bytes(self.c)

    def _from_bytes(self, vals: dict, data: bytes, offset: int = 0) -> int:
        ''' Decode all fields of STRUCT into vals, starting at offset.

        Return the offset after the last field.  Any exception thrown by
        a field is re-raised as DecodeError (chained to the original).
        Exceptions raised by self.check() are propagated as is. '''
        try: # Fields throw exceptions
            for f in self.STRUCT:
                offset += f.from_bytes(vals, data[offset:])
        except Exception as e:
            # Add contextual info
            raise DecodeError(self, f, offset) from e
        if self.check_len and len(data) != offset:
            raise DecodeError(self, 'Unhandled tail octets: %s'
                                    % data[offset:].hex())
        self.check(vals) # Check the content after decoding (raises exceptions)
        return offset

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode all fields of STRUCT taking values from vals.

        Any exception thrown by a field is re-raised as EncodeError
        (chained to the original).  Exceptions raised by self.check()
        are propagated as is. '''
        def proc(f: Codec):
            try: # Fields throw exceptions
                return f.to_bytes(vals)
            except Exception as e:
                # Add contextual info
                raise EncodeError(self, f) from e
        self.check(vals) # Check the content before encoding (raises exceptions)
        return b''.join([proc(f) for f in self.STRUCT])

    class F(Field):
        ''' Field wrapper. '''

        def __init__(self, e: 'Envelope', name: str, **kw) -> None:
            ''' Wrap the given envelope into a field with the given name. '''
            Field.__init__(self, name, **kw)
            self.e = e

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            ''' Decode the envelope into a new dict stored in vals. '''
            vals[self.name] = { }
            self.e._from_bytes(vals[self.name], data)

        def _to_bytes(self, vals: dict) -> bytes:
            ''' Encode the envelope taking values from the field value. '''
            return self.e._to_bytes(self.get_val(vals))

    def f(self, name: str, **kw) -> Field:
        ''' Return a field wrapping this envelope. '''
        return self.F(self, name, **kw)


class Sequence:
    ''' A sequence of repeating elements (e.g. TLVs). '''

    # The item of sequence
    ITEM = None # type: Optional[Envelope]

    def __init__(self, **kw) -> None:
        ''' Create a sequence of the given 'item' (defaults to ITEM).

        Raise ProtocolError if no item is defined. '''
        if (self.ITEM is None) and ('item' not in kw):
            raise ProtocolError('Missing Sequence item')
        self._item = kw.get('item', self.ITEM) # type: Envelope
        # The buffer is shared by all items, so tail octets are expected
        self._item.check_len = False

    def from_bytes(self, data: bytes) -> list:
        ''' Decode items from the given buffer until it is exhausted.

        Return a list of dicts (one per item).  Raise DecodeError if
        an item does not consume any octets. '''
        proc = self._item._from_bytes
        vseq, offset = [], 0
        length = len(data)

        while offset < length:
            vseq.append({ }) # new item of sequence
            consumed = proc(vseq[-1], data[offset:])
            # An item that makes no progress would loop here forever
            if consumed <= 0:
                raise DecodeError(self, 'Sequence item consumed no octets '
                                        'at offset %d' % offset)
            offset += consumed

        return vseq

    def to_bytes(self, vseq: list) -> bytes:
        ''' Encode the given list of dicts (one per item) into bytes. '''
        proc = self._item._to_bytes
        return b''.join([proc(v) for v in vseq])

    class F(Field):
        ''' Field wrapper. '''

        def __init__(self, s: 'Sequence', name: str, **kw) -> None:
            ''' Wrap the given sequence into a field with the given name. '''
            Field.__init__(self, name, **kw)
            self.s = s

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            ''' Decode the sequence into a list stored in vals. '''
            vals[self.name] = self.s.from_bytes(data)

        def _to_bytes(self, vals: dict) -> bytes:
            ''' Encode the sequence taking the list from the field value. '''
            return self.s.to_bytes(self.get_val(vals))

    def f(self, name: str, **kw) -> Field:
        ''' Return a field wrapping this sequence. '''
        return self.F(self, name, **kw)