You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

121 lines
3.3 KiB

"""
Copyright (C) 2019 Splunk Inc. All Rights Reserved.
"""
import json
from .fields import BaseField, FieldValidationError, StringField
from .migrators import Migrator
import collections
def get_schema_version():
    """
    Look up the schema version currently shipped with this package.

    The value is the package-level ``__version__``. It is imported lazily,
    inside the function body; presumably this avoids a circular import at
    module load time (TODO confirm).

    :return: the schema version
    :rtype: str
    """
    from . import __version__ as current_version
    return current_version
class ValidationError(Exception):
    """Error signalling that model data is invalid or failed validation."""

    pass
class BaseModel(object):
    """A base model implementation that provides additional helper functionality.

    Subclasses declare their schema as class attributes that are ``BaseField``
    instances. On construction the input data optionally goes through the
    migrator. Each declared field is then stored as an instance attribute,
    and the model can optionally be validated.
    """

    # Helper object to handle data migrations between different model versions
    migrator = Migrator()

    # The model version; defaults to the value returned by get_schema_version
    version = StringField(default=get_schema_version, required=True)

    def __init__(self, data=None, auto_validate=True, auto_migrate=True):
        """
        :param dict data: The model data. The caller's dict is left untouched;
            a shallow copy is used for filling in the version and migrating.
        :param bool auto_validate: Whether to automatically call validate on the model
        :param bool auto_migrate: Whether to automatically migrate the model's
            data to the latest model version
        :raises ValidationError: if ``auto_validate`` is true and a field's
            value is rejected by that field's ``validate``
        """
        # Work on a copy. Without it, setting 'version' (and any in-place
        # edits the migrator might make) would leak into the caller's dict.
        data = {} if data is None else dict(data)
        if data.get('version') is None:
            # self.version is still the class-level StringField here, because
            # instance attributes are only assigned in _populate_model below.
            data['version'] = self.version.default()
        if auto_migrate:
            data = self.migrator.migrate(data)
        self._fields = self._extract_fields()
        self._populate_model(self._fields, data)
        if auto_validate:
            self.validate()

    def validate(self):
        """
        Validates the current model based on its schema of fields.

        :raises ValidationError: for the first field whose ``validate`` raises
            ``FieldValidationError``; the message is prefixed with the quoted
            field name.
        """
        for name, field in self._fields.items():
            value = getattr(self, name)
            try:
                field.validate(value)
            except FieldValidationError as ex:
                raise ValidationError('"{}" {}'.format(name, ex))

    def raw_data(self):
        """
        Returns a dict obj that represents the current model

        :return: a dict mapping each declared field name to its current value
        :rtype: dict
        """
        return {name: getattr(self, name) for name in self._fields}

    def json(self):
        """
        Returns a JSON string that represents the current model

        :return: a JSON string of ``raw_data()``
        :rtype: str
        """
        return json.dumps(self.raw_data())

    def _extract_fields(self):
        """
        Returns a mapping dict of field name to field types

        Fields are collected from every class in the hierarchy, not only
        BaseModel and the concrete class. This way, fields declared on
        intermediate base classes are not lost.

        :return: a dict from str to an instance of BaseField
        :rtype: dict
        """
        fields = {}
        # Visit classes from most generic to most derived. A field redefined
        # in a subclass then overrides the inherited one, mirroring normal
        # attribute-lookup order.
        for klass in reversed(type(self).__mro__):
            for name, value in vars(klass).items():
                if isinstance(value, BaseField):
                    fields[name] = value
        return fields

    def _populate_model(self, fields, data):
        """
        Populates the model with values from data for each field

        A missing or None value is replaced by the field's default when the
        field has one. A callable default is called to produce the value.

        :param dict fields: a dict from field name to BaseField
        :param dict data: a dict from field name to data value
        """
        for name, field in fields.items():
            value = data.get(name)
            default = field.default
            if value is None and default is not None:
                # collections.Callable no longer exists on Python 3.10+;
                # the callable() builtin performs the same check portably.
                # NOTE(review): a non-callable mutable default is assigned
                # as-is, so every instance would share that same object.
                value = default() if callable(default) else default
            setattr(self, name, value)