forked from Mirror/pmbootstrap
config: sanity check via types (MR 2252)
Replace the "sanity_check" code with type checking built into the Config __setattr__ operator. This keeps all the Config related code in one place. Signed-off-by: Caleb Connolly <caleb@postmarketos.org>
This commit is contained in:
parent
7a8deb0f5e
commit
1a01738d50
6 changed files with 39 additions and 43 deletions
|
@ -1,5 +1,6 @@
|
|||
|
||||
from copy import deepcopy
|
||||
import enum
|
||||
import multiprocessing
|
||||
from typing import Any, List, Dict, TypedDict
|
||||
from pathlib import Path
|
||||
|
@ -11,6 +12,20 @@ class Mirrors(TypedDict):
|
|||
systemd: str
|
||||
|
||||
|
||||
class SystemdConfig(enum.Enum):
|
||||
DEFAULT = "default"
|
||||
ALWAYS = "always"
|
||||
NEVER = "never"
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
@staticmethod
|
||||
def choices() -> List[str]:
|
||||
return [e.value for e in SystemdConfig]
|
||||
|
||||
|
||||
class Config():
|
||||
aports: List[Path] = [Path(os.path.expanduser("~") +
|
||||
"/.local/var/pmbootstrap/cache_git/pmaports")]
|
||||
|
@ -41,7 +56,7 @@ class Config():
|
|||
ssh_key_glob: str = "~/.ssh/id_*.pub"
|
||||
ssh_keys: bool = False
|
||||
sudo_timer: bool = False
|
||||
systemd: str = "default"
|
||||
systemd: SystemdConfig = SystemdConfig.DEFAULT
|
||||
timezone: str = "GMT"
|
||||
ui: str = "console"
|
||||
ui_extras: bool = False
|
||||
|
@ -55,7 +70,7 @@ class Config():
|
|||
# Make sure we aren't modifying the class defaults
|
||||
for key in Config.__annotations__.keys():
|
||||
setattr(self, key, deepcopy(Config.get_default(key)))
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def keys() -> List[str]:
|
||||
|
@ -78,22 +93,29 @@ class Config():
|
|||
raise ValueError(f"Invalid dotted key: {dotted_key}")
|
||||
|
||||
|
||||
def __setattr__(self, key: str, value: str):
|
||||
def __setattr__(self, key: str, value: Any):
|
||||
"""Allow for setattr() to be used with a dotted key
|
||||
to set nested dictionaries (e.g. "mirrors.alpine")."""
|
||||
keys = key.split(".")
|
||||
if len(keys) == 1:
|
||||
super(Config, self).__setattr__(key, value)
|
||||
_type = type(getattr(Config, key))
|
||||
try:
|
||||
super(Config, self).__setattr__(key, _type(value))
|
||||
except ValueError:
|
||||
msg = f"Invalid value for '{key}': '{value}' "
|
||||
if issubclass(_type, enum.Enum):
|
||||
valid = [x.value for x in _type]
|
||||
msg += f"(valid values: {', '.join(valid)})"
|
||||
else:
|
||||
msg += f"(expected {_type}, got {type(value)})"
|
||||
raise ValueError(msg)
|
||||
elif len(keys) == 2:
|
||||
#print(f"cfgset, before: {super(Config, self).__getattribute__(keys[0])[keys[1]]}")
|
||||
super(Config, self).__getattribute__(keys[0])[keys[1]] = value
|
||||
#print(f"cfgset, after: {super(Config, self).__getattribute__(keys[0])[keys[1]]}")
|
||||
else:
|
||||
raise ValueError(f"Invalid dotted key: {key}")
|
||||
|
||||
|
||||
def __getattribute__(self, key: str) -> str:
|
||||
#print(repr(self))
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
"""Allow for getattr() to be used with a dotted key
|
||||
to get nested dictionaries (e.g. "mirrors.alpine")."""
|
||||
keys = key.split(".")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue