aegis_sim.parameterization.parametermanager

 1import logging
 2import yaml
 3import types
 4
 5from aegis_sim.parameterization.default_parameters import (
 6    get_default_parameters,
 7    DEFAULT_PARAMETERS,
 8    get_species_parameters,
 9)
10
11
class ParameterManager:
    """Merge simulation parameters from defaults, a species preset, a YAML config file and call arguments.

    The object is configured through ``init`` (there is no ``__init__``); ``init`` stores the
    parameter sources and immediately resolves them by calling the instance.
    """

    def init(self, custom_config_path, custom_input_params):
        """Store the parameter sources and resolve them into ``self.parameters``.

        Args:
            custom_config_path: Path to a YAML configuration file; ``""`` or ``None`` means
                no file.
            custom_input_params: Mapping of overrides applied last, on top of every other
                source.
        """
        self.custom_config_path = custom_config_path
        self.custom_input_params = custom_input_params
        # Reset before resolving; __call__ fills it with the merged plain dict.
        self.final_config = None

        self.parameters = self()

    def __call__(self):
        """Resolve the final simulation parameters.

        Parameters are merged from four sources; when a name appears in several,
        the later source wins:

        1. Defaults (``get_default_parameters``)
        2. Species preset (``get_species_parameters``), selected by ``SPECIES_PRESET``
           from the config file, falling back to the default preset
        3. Configuration file (checked with ``validate``)
        4. Function arguments (``self.custom_input_params``; not validated here)

        Side effects:
            Sets ``self.final_config`` to a plain-dict copy of the merged parameters.

        Returns:
            types.SimpleNamespace: the merged parameters as attributes.
        """
        default_parameters = get_default_parameters()
        custom_config_params = self.read_config_file()
        self.validate(custom_config_params)

        for k in default_parameters:
            if k in custom_config_params and default_parameters[k] != custom_config_params[k]:
                logging.debug(
                    "-- %s is different in config (%s) vs default (%s)",
                    k,
                    custom_config_params[k],
                    default_parameters[k],
                )

        # The preset named in the config file wins; otherwise use the default preset.
        species_preset = custom_config_params.get("SPECIES_PRESET", default_parameters["SPECIES_PRESET"])
        species_config_params = get_species_parameters(species_preset)

        logging.info("Using %s as species preset: %r.", species_preset, species_config_params)

        # Fuse: each update overrides the previous ones.
        params = {}
        params.update(default_parameters)
        params.update(species_config_params)
        params.update(custom_config_params)
        params.update(self.custom_input_params)

        self.final_config = params.copy()

        # Expose as attribute access (params.NAME) for the simulation.
        namespace = types.SimpleNamespace(**params)
        logging.info("Final parameters to use in the simulation: %r.", namespace)
        return namespace

    def read_config_file(self):
        """Load the YAML configuration file named by ``self.custom_config_path``.

        Returns:
            dict: the parsed configuration, or ``{}`` when no path is given
            (``""`` or ``None``) or the file is empty.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if the top level of the YAML document is not a mapping.
        """
        # No configuration file specified
        if self.custom_config_path is None or self.custom_config_path == "":
            logging.info("No configuration file has been specified.")
            return {}

        # Configuration file specified...
        # Read bytes so PyYAML detects UTF-8/UTF-16 itself instead of using the locale encoding.
        with open(self.custom_config_path, "rb") as file_:
            ccp = yaml.safe_load(file_)

        # ... but it is empty
        if ccp is None:
            logging.info("Configuration file is empty.")
            return {}

        # Callers use .items()/.get() on the result; fail early with a clear message otherwise.
        if not isinstance(ccp, dict):
            raise ValueError(
                f"Configuration file '{self.custom_config_path}' must contain a mapping of "
                f"parameter names to values, not {type(ccp).__name__}"
            )

        return ccp

    @staticmethod
    def validate(pdict, validate_serverrange=False):
        """Check parameter names and values in ``pdict`` against ``DEFAULT_PARAMETERS``.

        Args:
            pdict: Mapping of parameter name -> value to check.
            validate_serverrange: When true, additionally run each parameter's
                ``validate_serverrange`` check.

        Raises:
            ValueError: if a key matches no ``DEFAULT_PARAMETERS`` entry's ``.key``.
            Whatever the parameter's ``validate_dtype`` / ``validate_inrange`` /
            ``validate_serverrange`` hooks raise (defined elsewhere; not visible here).
        """
        for key, val in pdict.items():
            # Validate key
            if all(key != p.key for p in DEFAULT_PARAMETERS.values()):
                raise ValueError(f"'{key}' is not a valid parameter name")

            # Validate value type and range
            param = DEFAULT_PARAMETERS[key]
            param.validate_dtype(val)
            param.validate_inrange(val)

            if validate_serverrange:
                param.validate_serverrange(val)
class ParameterManager:
class ParameterManager:
    """Merge simulation parameters from defaults, a species preset, a YAML config file and call arguments.

    The object is configured through ``init`` (there is no ``__init__``); ``init`` stores the
    parameter sources and immediately resolves them by calling the instance.
    """

    def init(self, custom_config_path, custom_input_params):
        """Store the parameter sources and resolve them into ``self.parameters``.

        Args:
            custom_config_path: Path to a YAML configuration file; ``""`` or ``None`` means
                no file.
            custom_input_params: Mapping of overrides applied last, on top of every other
                source.
        """
        self.custom_config_path = custom_config_path
        self.custom_input_params = custom_input_params
        # Reset before resolving; __call__ fills it with the merged plain dict.
        self.final_config = None

        self.parameters = self()

    def __call__(self):
        """Resolve the final simulation parameters.

        Parameters are merged from four sources; when a name appears in several,
        the later source wins:

        1. Defaults (``get_default_parameters``)
        2. Species preset (``get_species_parameters``), selected by ``SPECIES_PRESET``
           from the config file, falling back to the default preset
        3. Configuration file (checked with ``validate``)
        4. Function arguments (``self.custom_input_params``; not validated here)

        Side effects:
            Sets ``self.final_config`` to a plain-dict copy of the merged parameters.

        Returns:
            types.SimpleNamespace: the merged parameters as attributes.
        """
        default_parameters = get_default_parameters()
        custom_config_params = self.read_config_file()
        self.validate(custom_config_params)

        for k in default_parameters:
            if k in custom_config_params and default_parameters[k] != custom_config_params[k]:
                logging.debug(
                    "-- %s is different in config (%s) vs default (%s)",
                    k,
                    custom_config_params[k],
                    default_parameters[k],
                )

        # The preset named in the config file wins; otherwise use the default preset.
        species_preset = custom_config_params.get("SPECIES_PRESET", default_parameters["SPECIES_PRESET"])
        species_config_params = get_species_parameters(species_preset)

        logging.info("Using %s as species preset: %r.", species_preset, species_config_params)

        # Fuse: each update overrides the previous ones.
        params = {}
        params.update(default_parameters)
        params.update(species_config_params)
        params.update(custom_config_params)
        params.update(self.custom_input_params)

        self.final_config = params.copy()

        # Expose as attribute access (params.NAME) for the simulation.
        namespace = types.SimpleNamespace(**params)
        logging.info("Final parameters to use in the simulation: %r.", namespace)
        return namespace

    def read_config_file(self):
        """Load the YAML configuration file named by ``self.custom_config_path``.

        Returns:
            dict: the parsed configuration, or ``{}`` when no path is given
            (``""`` or ``None``) or the file is empty.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if the top level of the YAML document is not a mapping.
        """
        # No configuration file specified
        if self.custom_config_path is None or self.custom_config_path == "":
            logging.info("No configuration file has been specified.")
            return {}

        # Configuration file specified...
        # Read bytes so PyYAML detects UTF-8/UTF-16 itself instead of using the locale encoding.
        with open(self.custom_config_path, "rb") as file_:
            ccp = yaml.safe_load(file_)

        # ... but it is empty
        if ccp is None:
            logging.info("Configuration file is empty.")
            return {}

        # Callers use .items()/.get() on the result; fail early with a clear message otherwise.
        if not isinstance(ccp, dict):
            raise ValueError(
                f"Configuration file '{self.custom_config_path}' must contain a mapping of "
                f"parameter names to values, not {type(ccp).__name__}"
            )

        return ccp

    @staticmethod
    def validate(pdict, validate_serverrange=False):
        """Check parameter names and values in ``pdict`` against ``DEFAULT_PARAMETERS``.

        Args:
            pdict: Mapping of parameter name -> value to check.
            validate_serverrange: When true, additionally run each parameter's
                ``validate_serverrange`` check.

        Raises:
            ValueError: if a key matches no ``DEFAULT_PARAMETERS`` entry's ``.key``.
            Whatever the parameter's ``validate_dtype`` / ``validate_inrange`` /
            ``validate_serverrange`` hooks raise (defined elsewhere; not visible here).
        """
        for key, val in pdict.items():
            # Validate key
            if all(key != p.key for p in DEFAULT_PARAMETERS.values()):
                raise ValueError(f"'{key}' is not a valid parameter name")

            # Validate value type and range
            param = DEFAULT_PARAMETERS[key]
            param.validate_dtype(val)
            param.validate_inrange(val)

            if validate_serverrange:
                param.validate_serverrange(val)
def init(self, custom_config_path, custom_input_params):
14    def init(self, custom_config_path, custom_input_params):
15        self.custom_config_path = custom_config_path
16        self.custom_input_params = custom_input_params
17        self.final_config = None
18
19        self.parameters = self()
def read_config_file(self):
59    def read_config_file(self):
60
61        # No configuration file specified
62        if self.custom_config_path == "":
63            logging.info("No configuration file has been specified.")
64            return {}
65
66        # Configuration file specified...
67        with open(self.custom_config_path, "r") as file_:
68            ccp = yaml.safe_load(file_)
69
70        # ... but it is empty
71        if ccp is None:
72            logging.info("Configuration file is empty.")
73            ccp = {}
74
75        return ccp
@staticmethod
def validate(pdict, validate_serverrange=False):
77    @staticmethod
78    def validate(pdict, validate_serverrange=False):
79        for key, val in pdict.items():
80            # Validate key
81            if all(key != p.key for p in DEFAULT_PARAMETERS.values()):
82                raise ValueError(f"'{key}' is not a valid parameter name")
83
84            # Validate value type and range
85            DEFAULT_PARAMETERS[key].validate_dtype(val)
86            DEFAULT_PARAMETERS[key].validate_inrange(val)
87
88            if validate_serverrange:
89                DEFAULT_PARAMETERS[key].validate_serverrange(val)