123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- """
- Provides functionality for validation of the data-types specified
- for odml
- """
- import sys
- self = sys.modules[__name__].__dict__
- import datetime
- import binascii
- import hashlib
- from enum import Enum
class DType(str, Enum):
    """
    enumeration of the data type names known to this module

    The members are ``str`` instances as well, so every member compares
    equal to its own name (``DType.int == 'int'``).
    """

    string = 'string'
    text = 'text'
    int = 'int'
    float = 'float'
    url = 'url'
    datetime = 'datetime'
    date = 'date'
    time = 'time'
    boolean = 'boolean'
    person = 'person'
    binary = 'binary'

    def __str__(self):
        """return the plain member name rather than ``DType.<name>``"""
        member_name = self.name
        return member_name
# python type names whose dtype name differs from ``type(x).__name__``
_dtype_map = {'str': 'string', 'bool': 'boolean'}


def infer_dtype(value):
    """
    derive the dtype name from the python type of *value*

    returns the dtype name if valid_type() accepts it, otherwise None;
    a string that contains a newline is reported as 'text'
    """
    name = type(value).__name__
    name = _dtype_map.get(name, name)
    if not valid_type(name):
        return None
    if name == 'string' and '\n' in value:
        return 'text'
    return name
def valid_type(dtype):
    """
    checks if *dtype* is a valid type

    Valid are None (no dtype given), the names of the DType members, the
    aliases listed in _dtype_map and "<n>-tuple" where <n> parses as an
    integer. The comparison ignores case.
    """
    # None has to be handled before any string method is called on dtype
    if dtype is None:
        return True
    dtype = dtype.lower()
    dtype = _dtype_map.get(dtype, dtype)
    # look at the declared members only: DType derives from str, so a
    # hasattr() test would also accept str method names like 'lower'
    if dtype in DType.__members__:
        return True
    if dtype.endswith("-tuple"):
        try:
            int(dtype[:-6])
        except ValueError:
            return False
        return True
    return False
- # TODO also take encoding into account
def validate(string, dtype):
    """
    checks if:

     * *dtype* is a valid type
     * *string* is a valid expression of type *dtype*

    returns True if both conditions hold and False otherwise
    """
    try:
        if dtype is None:
            # no dtype given: the value only has to be convertible to text
            str_get(string)
            return True
        if not valid_type(dtype):
            return False
        # resolve the name the same way valid_type does
        name = dtype.lower()
        name = _dtype_map.get(name, name)
        if name.endswith("-tuple"):
            # the element count is part of the type name; be strict about it
            tuple_get(string, count=int(name[:-6]))
        else:
            # self is the module namespace dict; types without a dedicated
            # "<name>_get" parser are treated as plain strings
            self.get(name + "_get", str_get)(string)
    except (ValueError, TypeError, AttributeError, AssertionError,
            RuntimeError):
        # any parse error: the value is no valid expression of this type
        return False
    return True
def get(string, dtype=None, encoding=None):
    """
    convert *string* to the corresponding *dtype*

    Without a dtype the value goes through str_get. *encoding* is only
    passed on for the "binary" dtype. Any other dtype is handled by the
    module-level function "<dtype>_get", with str_get as the fallback.
    """
    if not dtype:
        return str_get(string)
    # special case, as the count-number is included in the type-name
    if dtype.endswith("-tuple"):
        return tuple_get(string)
    if dtype == "binary":
        return binary_get(string, encoding)
    # self is the module namespace dict
    converter = self.get(dtype + "_get", str_get)
    return converter(string)
def set(value, dtype=None, encoding=None):
    """
    serialize a *value* of type *dtype* to a unicode string

    Tuple and binary dtypes are delegated to tuple_set / binary_set.
    Text values are passed to str_set whatever the dtype is. Everything
    else is handled by the module-level function "<dtype>_set", with
    str_set as the fallback.
    """
    if not dtype:
        return str_set(value)
    if dtype.endswith("-tuple"):
        return tuple_set(value)
    if dtype == "binary":
        return binary_set(value, encoding)
    if sys.version_info > (3, 0):
        is_text = isinstance(value, str)
    else:
        is_text = type(value) in (str, unicode)
    # self is the module namespace dict
    serializer = str_set if is_text else self.get(dtype + "_set", str_set)
    return serializer(value)
def int_get(string):
    """
    convert *string* to an int; empty values yield 0

    Text that int() rejects, such as "3.7", is parsed as a float and
    truncated.
    """
    if string:
        try:
            return int(string)
        except ValueError:
            # no integer literal: go through float and cut off the fraction
            return int(float(string))
    return 0
def float_get(string):
    """convert *string* to a float; empty values yield 0.0"""
    return float(string) if string else 0.0
def str_get(string):
    """
    convert *string* to the text type of the running interpreter
    (unicode on Python 2, str on Python 3)
    """
    if sys.version_info >= (3, 0):
        return str(string)
    return unicode(string)
def str_set(value):
    """
    serialize *value* with the text type of the running interpreter
    (unicode on Python 2, str on Python 3)

    An exception raised by the conversion is passed on to the caller.
    """
    to_text = unicode if sys.version_info < (3, 0) else str
    return to_text(value)
def time_get(string):
    """
    convert *string* ("HH:MM:SS") to a datetime.time; empty values
    yield None

    A datetime.time is accepted as well. It is reduced to what the string
    format can hold: whole seconds and no timezone.

    raises ValueError if the text does not match "HH:MM:SS"
    """
    if not string:
        return None
    if isinstance(string, datetime.time):
        # previously .time() was called on the str returned by strftime(),
        # which raised AttributeError for every time object
        return string.replace(microsecond=0, tzinfo=None)
    return datetime.datetime.strptime(string, '%H:%M:%S').time()
def time_set(value):
    """
    serialize *value*: a datetime.time becomes "HH:MM:SS", any other
    object is serialized through its isoformat() method; empty values
    yield None
    """
    if value:
        if type(value) is datetime.time:
            return value.strftime("%H:%M:%S")
        return value.isoformat()
    return None
def date_get(string):
    """
    convert *string* ("YYYY-MM-DD") to a datetime.date; empty values
    yield None

    A datetime.date is rendered with isoformat() and parsed again.
    """
    if not string:
        return None
    text = string.isoformat() if type(string) is datetime.date else string
    return datetime.datetime.strptime(text, '%Y-%m-%d').date()
def date_set(value):
    """
    serialize a date as "YYYY-MM-DD"; empty values yield None

    A datetime.datetime is reduced to its calendar date, so that the
    result can be read back by a "%Y-%m-%d" parser. Other objects are
    serialized through their isoformat() method.
    """
    if not value:
        return None
    if isinstance(value, datetime.datetime):
        # a datetime is a date too, but its isoformat() includes the time
        return value.date().isoformat()
    if type(value) is datetime.time:
        # behaviour of the former alias to time_set, kept for compatibility
        return value.strftime("%H:%M:%S")
    return value.isoformat()
def datetime_get(string):
    """
    convert *string* ("YYYY-MM-DD HH:MM:SS") to a datetime.datetime;
    empty values yield None

    A datetime.datetime is accepted as well. It is reduced to what the
    string format can hold: whole seconds and no timezone.

    raises ValueError if the text does not match the format
    """
    if not string:
        return None
    if isinstance(string, datetime.datetime):
        # previously the formatted str was returned here instead of a
        # datetime object
        return string.replace(microsecond=0, tzinfo=None)
    return datetime.datetime.strptime(string, '%Y-%m-%d %H:%M:%S')
def datetime_set(value):
    """
    serialize *value* as "YYYY-MM-DD HH:MM:SS"; empty values yield None

    Text is parsed first, so malformed input raises ValueError, and is
    then returned in the same format as a datetime.datetime would be.
    """
    if not value:
        return None
    if not isinstance(value, datetime.datetime):
        # previously the parsed datetime object itself was returned, so
        # the serializer did not produce a string for text input
        value = datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
    return value.strftime('%Y-%m-%d %H:%M:%S')
def boolean_get(string):
    """
    convert *string* to a bool

    bool values are returned unchanged, other empty values yield None.
    Text is compared case-insensitively against "true"/"t"/"1" and
    "false"/"f"/"0".

    raises ValueError for any other text
    """
    # must come before the emptiness test, which would turn False into None
    if isinstance(string, bool):
        return string
    if not string:
        return None
    string = string.lower()
    # be kind, spec only accepts True / False
    if string in ("true", "t", "1"):
        return True
    if string in ("false", "f", "0"):
        return False
    raise ValueError("Cannot interpret '%s' as boolean" % string)
def boolean_set(value):
    """serialize *value* with str(); None is passed through unchanged"""
    return None if value is None else str(value)
def tuple_get(string, count=None):
    """
    parse a tuple string like "(1024;768)" and return strings of the elements

    Empty values yield None. If *count* is given, the number of elements
    has to match it exactly.

    raises AssertionError if the text is not enclosed in parentheses or
    the number of elements differs from *count*
    """
    if not string:
        return None
    string = string.strip()
    # raised explicitly instead of using assert statements, which are
    # removed when python runs with -O
    if not (string.startswith("(") and string.endswith(")")):
        raise AssertionError(
            "tuple value has to be enclosed in parentheses: %r" % (string,))
    res = string[1:-1].split(";")
    if count is not None and len(res) != count:  # be strict
        raise AssertionError(
            "expected %d tuple elements, found %d" % (count, len(res)))
    return res
def tuple_set(value):
    """
    serialize the elements of *value* as "(a;b;...)"; empty values
    yield None

    The elements do not have to be strings: each one is converted with
    "%s" formatting.
    """
    if not value:
        return None
    # "%s" formatting converts non-string elements such as ints, which a
    # bare join() rejects with TypeError, and leaves text elements as is
    return "(%s)" % ";".join(["%s" % (item,) for item in value])
- ###############################################################################
- # Binary Encoding Stuff
- ###############################################################################
class Encoder(object):
    """pairs an encoding function with the matching decoding function"""

    def __init__(self, encode, decode):
        """store the two callables used by encode() and decode()"""
        self._encode, self._decode = encode, decode

    def encode(self, data):
        """
        run *data* through the encoding function; on Python 3 a str is
        converted to bytes first
        """
        needs_bytes = sys.version_info > (3, 0) and isinstance(data, str)
        payload = str.encode(data) if needs_bytes else data
        return self._encode(payload)

    def decode(self, string):
        """run *string* through the decoding function"""
        return self._decode(string)
def _base64_encode(data):
    """base64-encode *data* and strip surrounding whitespace from the result"""
    return binascii.b2a_base64(data).strip()


def _passthrough(data):
    """hand *data* back unchanged"""
    return data


# the available binary encodings, addressed by name; None selects the
# identity encoder
encodings = dict([
    ('base64', Encoder(_base64_encode, binascii.a2b_base64)),
    ('quoted-printable', Encoder(binascii.b2a_qp, binascii.a2b_qp)),
    ('hexadecimal', Encoder(binascii.b2a_hex, binascii.a2b_hex)),
    (None, Encoder(_passthrough, _passthrough)),
])
def valid_encoding(encoding):
    """checks if *encoding* is a key of the encodings table"""
    is_known = encoding in encodings
    return is_known
def binary_get(string, encoding=None):
    """
    binary decode the *string* according to *encoding*

    Empty values yield None. *encoding* has to be a key of the encodings
    table, otherwise the lookup raises KeyError.
    """
    if string:
        codec = encodings[encoding]
        return codec.decode(string)
    return None
def binary_set(value, encoding=None):
    """
    binary encode the *value* according to *encoding*

    Empty values yield None. *encoding* has to be a key of the encodings
    table, otherwise the lookup raises KeyError.
    """
    if value:
        codec = encodings[encoding]
        return codec.encode(value)
    return None
def calculate_crc32_checksum(data):
    """
    return the CRC32 of *data* as eight lowercase hexadecimal digits

    On Python 3 a str is converted to bytes first.
    """
    if sys.version_info >= (3, 0) and isinstance(data, str):
        data = str.encode(data)
    # mask the result to an unsigned 32-bit value
    crc = binascii.crc32(data) & 0xffffffff
    return "%08x" % crc
# checksum functions by algorithm name; each one takes the data and
# returns the checksum as a hexadecimal string
checksums = {
    'crc32': calculate_crc32_checksum,
}

# allow to use any available algorithm
if sys.version_info > (3, 0):
    _hash_names = hashlib.algorithms_guaranteed
elif not sys.version_info < (2, 7):
    _hash_names = hashlib.algorithms
else:
    _hash_names = ()

for algo in _hash_names:
    # the SHAKE functions produce variable-length digests: their
    # hexdigest() requires a length argument, so an entry created with the
    # one-argument wrapper below would raise TypeError on every call
    if algo.startswith("shake_"):
        continue
    # func is bound as a default argument so that each entry keeps its
    # own hash constructor
    checksums[algo] = lambda data, func=getattr(hashlib, algo): func(data).hexdigest()
def valid_checksum_type(checksum_type):
    """checks if *checksum_type* is a key of the checksums table"""
    is_known = checksum_type in checksums
    return is_known
def calculate_checksum(data, checksum_type):
    """
    return the checksum of *data* computed by the algorithm registered
    under *checksum_type* in the checksums table

    None is treated as empty data. On Python 3 a str is converted to
    bytes first.

    raises KeyError if *checksum_type* is not registered
    """
    if data is None:
        data = ''
    # only Python 3 text needs converting; on Python 2 a str already holds
    # bytes, and str.encode() would push it through ASCII and fail on
    # binary content (calculate_crc32_checksum makes the same distinction)
    if sys.version_info >= (3, 0) and isinstance(data, str):
        data = str.encode(data)
    return checksums[checksum_type](data)
|