config: sanity check via types (MR 2252)

Replace the "sanity_check" code with type checking built into the Config
__setattr__ operator.

This keeps all the Config related code in one place.

Signed-off-by: Caleb Connolly <caleb@postmarketos.org>
This commit is contained in:
Caleb Connolly 2024-06-09 03:27:45 +02:00 committed by Oliver Smith
parent 7a8deb0f5e
commit 1a01738d50
No known key found for this signature in database
GPG key ID: 5AE7F5513E0885CB
6 changed files with 39 additions and 43 deletions

View file

@ -1,5 +1,6 @@
from copy import deepcopy
import enum
import multiprocessing
from typing import Any, List, Dict, TypedDict
from pathlib import Path
@ -11,6 +12,20 @@ class Mirrors(TypedDict):
systemd: str
class SystemdConfig(enum.Enum):
DEFAULT = "default"
ALWAYS = "always"
NEVER = "never"
def __str__(self) -> str:
return self.value
@staticmethod
def choices() -> List[str]:
return [e.value for e in SystemdConfig]
class Config():
aports: List[Path] = [Path(os.path.expanduser("~") +
"/.local/var/pmbootstrap/cache_git/pmaports")]
@ -41,7 +56,7 @@ class Config():
ssh_key_glob: str = "~/.ssh/id_*.pub"
ssh_keys: bool = False
sudo_timer: bool = False
systemd: str = "default"
systemd: SystemdConfig = SystemdConfig.DEFAULT
timezone: str = "GMT"
ui: str = "console"
ui_extras: bool = False
@ -55,7 +70,7 @@ class Config():
# Make sure we aren't modifying the class defaults
for key in Config.__annotations__.keys():
setattr(self, key, deepcopy(Config.get_default(key)))
@staticmethod
def keys() -> List[str]:
@ -78,22 +93,29 @@ class Config():
raise ValueError(f"Invalid dotted key: {dotted_key}")
def __setattr__(self, key: str, value: str):
def __setattr__(self, key: str, value: Any):
"""Allow for setattr() to be used with a dotted key
to set nested dictionaries (e.g. "mirrors.alpine")."""
keys = key.split(".")
if len(keys) == 1:
super(Config, self).__setattr__(key, value)
_type = type(getattr(Config, key))
try:
super(Config, self).__setattr__(key, _type(value))
except ValueError:
msg = f"Invalid value for '{key}': '{value}' "
if issubclass(_type, enum.Enum):
valid = [x.value for x in _type]
msg += f"(valid values: {', '.join(valid)})"
else:
msg += f"(expected {_type}, got {type(value)})"
raise ValueError(msg)
elif len(keys) == 2:
#print(f"cfgset, before: {super(Config, self).__getattribute__(keys[0])[keys[1]]}")
super(Config, self).__getattribute__(keys[0])[keys[1]] = value
#print(f"cfgset, after: {super(Config, self).__getattribute__(keys[0])[keys[1]]}")
else:
raise ValueError(f"Invalid dotted key: {key}")
def __getattribute__(self, key: str) -> str:
#print(repr(self))
def __getattribute__(self, key: str) -> Any:
"""Allow for getattr() to be used with a dotted key
to get nested dictionaries (e.g. "mirrors.alpine")."""
keys = key.split(".")