aegis_sim.parameterization.parametermanager
import logging
import types

import yaml

from aegis_sim.parameterization.default_parameters import (
    get_default_parameters,
    DEFAULT_PARAMETERS,
    get_species_parameters,
)


class ParameterManager:
    """Assemble and hold the parameter set used by a simulation run.

    After ``init`` (or ``init_from_config``) the instance exposes:
        parameters: ``types.SimpleNamespace`` with one attribute per parameter.
        final_config: plain ``dict`` copy of the same values.
    """

    def init(self, custom_config_path, custom_input_params):
        """Record the parameter sources and resolve the final parameters.

        Args:
            custom_config_path: Path to a YAML configuration file; ``""`` or
                ``None`` means no configuration file.
            custom_input_params: Mapping of parameter overrides; these win over
                every other source. ``None`` is treated as no overrides.
        """
        self.custom_config_path = custom_config_path
        self.custom_input_params = custom_input_params
        self.final_config = None

        # Calling the instance runs __call__, which also fills in self.final_config.
        self.parameters = self()

    def init_from_config(self, final_config, custom_config_path):
        """Initialize from a saved config dict (used when resuming from checkpoint).

        The saved dict is used as-is: no defaults, species preset, config file or
        input overrides are merged in here.
        """
        self.custom_config_path = custom_config_path
        self.custom_input_params = {}
        self.final_config = final_config.copy()
        self.parameters = types.SimpleNamespace(**final_config)
        logging.info("Parameters restored from checkpoint.")

    def __call__(self):
        """Resolve parameters from all sources and return them as a namespace.

        Sources, from lowest to highest precedence (later ones overwrite earlier):
            1. Defaults (``get_default_parameters``)
            2. Species preset selected by ``SPECIES_PRESET``
            3. Configuration file (``read_config_file``)
            4. Function arguments (``self.custom_input_params``)

        Side effect: stores a plain-dict copy of the merged result in
        ``self.final_config``.

        Raises:
            ValueError: If the configuration file is not a mapping or names an
                unknown parameter. The per-parameter value checks run by
                ``validate`` may raise their own errors.
        """
        default_parameters = get_default_parameters()
        custom_config_params = self.read_config_file()
        self.validate(custom_config_params)
        # A missing set of overrides means "no overrides".
        custom_input_params = self.custom_input_params or {}

        for k in default_parameters:
            if k in custom_config_params and default_parameters[k] != custom_config_params[k]:
                logging.debug(
                    f"-- {k} is different in config ({custom_config_params[k]}) vs default ({default_parameters[k]})"
                )

        # Resolve the preset with the same precedence as every other parameter, so the
        # species values applied always match the SPECIES_PRESET reported in the result.
        species_preset = custom_input_params.get(
            "SPECIES_PRESET",
            custom_config_params.get("SPECIES_PRESET", default_parameters["SPECIES_PRESET"]),
        )
        species_config_params = get_species_parameters(species_preset)

        logging.info(f"Using {species_preset} as species preset: " + repr(species_config_params) + ".")

        # Fuse: each update overwrites values from the lower-precedence sources.
        params = {}
        params.update(default_parameters)
        params.update(species_config_params)
        params.update(custom_config_params)
        params.update(custom_input_params)

        self.final_config = params.copy()

        # Expose parameters as attributes (params.NAME) instead of dict keys.
        params = types.SimpleNamespace(**params)
        logging.info("Final parameters to use in the simulation: " + repr(params) + ".")
        return params

    def read_config_file(self):
        """Load the YAML configuration file as a dict of parameter overrides.

        Returns:
            dict: The parsed mapping, or ``{}`` when no file was specified
            (``custom_config_path`` is ``""`` or ``None``) or the file is empty.

        Raises:
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        path = self.custom_config_path

        # No configuration file specified
        if path is None or path == "":
            logging.info("No configuration file has been specified.")
            return {}

        # Binary mode lets PyYAML detect the encoding (UTF-8, or UTF-16 via BOM)
        # instead of depending on the platform's locale encoding.
        with open(path, "rb") as file_:
            ccp = yaml.safe_load(file_)

        # ... but it is empty
        if ccp is None:
            logging.info("Configuration file is empty.")
            return {}

        # Fail early with a clear message instead of an AttributeError in validate().
        if not isinstance(ccp, dict):
            raise ValueError(
                f"Configuration file '{path}' must contain a mapping of parameter names to values, "
                f"not {type(ccp).__name__}."
            )
        return ccp

    @staticmethod
    def validate(pdict, validate_serverrange=False):
        """Check that every entry of ``pdict`` is a known, well-formed parameter.

        Args:
            pdict: Mapping of parameter name to value.
            validate_serverrange: If true, also run each parameter's
                ``validate_serverrange`` check.

        Raises:
            ValueError: If a key is not a known parameter name. The per-parameter
                ``validate_*`` checks may raise their own errors for bad values.
        """
        # Collect the known names once instead of rescanning DEFAULT_PARAMETERS per key.
        valid_keys = frozenset(p.key for p in DEFAULT_PARAMETERS.values())
        for key, val in pdict.items():
            # Validate key
            if key not in valid_keys:
                raise ValueError(f"'{key}' is not a valid parameter name")

            # Validate value type and range
            param = DEFAULT_PARAMETERS[key]
            param.validate_dtype(val)
            param.validate_inrange(val)

            if validate_serverrange:
                param.validate_serverrange(val)
class ParameterManager:
class ParameterManager:
    """Assemble and hold the parameter set used by a simulation run.

    After ``init`` (or ``init_from_config``) the instance exposes:
        parameters: ``types.SimpleNamespace`` with one attribute per parameter.
        final_config: plain ``dict`` copy of the same values.
    """

    def init(self, custom_config_path, custom_input_params):
        """Record the parameter sources and resolve the final parameters.

        Args:
            custom_config_path: Path to a YAML configuration file; ``""`` or
                ``None`` means no configuration file.
            custom_input_params: Mapping of parameter overrides; these win over
                every other source. ``None`` is treated as no overrides.
        """
        self.custom_config_path = custom_config_path
        self.custom_input_params = custom_input_params
        self.final_config = None

        # Calling the instance runs __call__, which also fills in self.final_config.
        self.parameters = self()

    def init_from_config(self, final_config, custom_config_path):
        """Initialize from a saved config dict (used when resuming from checkpoint).

        The saved dict is used as-is: no defaults, species preset, config file or
        input overrides are merged in here.
        """
        self.custom_config_path = custom_config_path
        self.custom_input_params = {}
        self.final_config = final_config.copy()
        self.parameters = types.SimpleNamespace(**final_config)
        logging.info("Parameters restored from checkpoint.")

    def __call__(self):
        """Resolve parameters from all sources and return them as a namespace.

        Sources, from lowest to highest precedence (later ones overwrite earlier):
            1. Defaults (``get_default_parameters``)
            2. Species preset selected by ``SPECIES_PRESET``
            3. Configuration file (``read_config_file``)
            4. Function arguments (``self.custom_input_params``)

        Side effect: stores a plain-dict copy of the merged result in
        ``self.final_config``.

        Raises:
            ValueError: If the configuration file is not a mapping or names an
                unknown parameter. The per-parameter value checks run by
                ``validate`` may raise their own errors.
        """
        default_parameters = get_default_parameters()
        custom_config_params = self.read_config_file()
        self.validate(custom_config_params)
        # A missing set of overrides means "no overrides".
        custom_input_params = self.custom_input_params or {}

        for k in default_parameters:
            if k in custom_config_params and default_parameters[k] != custom_config_params[k]:
                logging.debug(
                    f"-- {k} is different in config ({custom_config_params[k]}) vs default ({default_parameters[k]})"
                )

        # Resolve the preset with the same precedence as every other parameter, so the
        # species values applied always match the SPECIES_PRESET reported in the result.
        species_preset = custom_input_params.get(
            "SPECIES_PRESET",
            custom_config_params.get("SPECIES_PRESET", default_parameters["SPECIES_PRESET"]),
        )
        species_config_params = get_species_parameters(species_preset)

        logging.info(f"Using {species_preset} as species preset: " + repr(species_config_params) + ".")

        # Fuse: each update overwrites values from the lower-precedence sources.
        params = {}
        params.update(default_parameters)
        params.update(species_config_params)
        params.update(custom_config_params)
        params.update(custom_input_params)

        self.final_config = params.copy()

        # Expose parameters as attributes (params.NAME) instead of dict keys.
        params = types.SimpleNamespace(**params)
        logging.info("Final parameters to use in the simulation: " + repr(params) + ".")
        return params

    def read_config_file(self):
        """Load the YAML configuration file as a dict of parameter overrides.

        Returns:
            dict: The parsed mapping, or ``{}`` when no file was specified
            (``custom_config_path`` is ``""`` or ``None``) or the file is empty.

        Raises:
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        path = self.custom_config_path

        # No configuration file specified
        if path is None or path == "":
            logging.info("No configuration file has been specified.")
            return {}

        # Binary mode lets PyYAML detect the encoding (UTF-8, or UTF-16 via BOM)
        # instead of depending on the platform's locale encoding.
        with open(path, "rb") as file_:
            ccp = yaml.safe_load(file_)

        # ... but it is empty
        if ccp is None:
            logging.info("Configuration file is empty.")
            return {}

        # Fail early with a clear message instead of an AttributeError in validate().
        if not isinstance(ccp, dict):
            raise ValueError(
                f"Configuration file '{path}' must contain a mapping of parameter names to values, "
                f"not {type(ccp).__name__}."
            )
        return ccp

    @staticmethod
    def validate(pdict, validate_serverrange=False):
        """Check that every entry of ``pdict`` is a known, well-formed parameter.

        Args:
            pdict: Mapping of parameter name to value.
            validate_serverrange: If true, also run each parameter's
                ``validate_serverrange`` check.

        Raises:
            ValueError: If a key is not a known parameter name. The per-parameter
                ``validate_*`` checks may raise their own errors for bad values.
        """
        # Collect the known names once instead of rescanning DEFAULT_PARAMETERS per key.
        valid_keys = frozenset(p.key for p in DEFAULT_PARAMETERS.values())
        for key, val in pdict.items():
            # Validate key
            if key not in valid_keys:
                raise ValueError(f"'{key}' is not a valid parameter name")

            # Validate value type and range
            param = DEFAULT_PARAMETERS[key]
            param.validate_dtype(val)
            param.validate_inrange(val)

            if validate_serverrange:
                param.validate_serverrange(val)
def init_from_config(self, final_config, custom_config_path):
def init_from_config(self, final_config, custom_config_path):
    """Restore the manager's state from a previously saved final config.

    Used when resuming from a checkpoint. The saved dict is taken as-is: no
    defaults, species preset, config file or input overrides are merged in.

    Args:
        final_config: Dict of parameter name to value, as previously saved.
        custom_config_path: Config path to record on the instance.
    """
    # A restored run has no separate input overrides.
    self.custom_config_path, self.custom_input_params = custom_config_path, {}
    # Keep a shallow copy so later edits to the caller's dict don't leak in.
    self.final_config = final_config.copy()
    restored = types.SimpleNamespace(**final_config)
    self.parameters = restored
    logging.info("Parameters restored from checkpoint.")
Initialize from a saved config dict (used when resuming from checkpoint).
def read_config_file(self):
def read_config_file(self):
    """Load the YAML configuration file as a dict of parameter overrides.

    Returns:
        dict: The parsed mapping, or ``{}`` when no file was specified
        (``custom_config_path`` is ``""`` or ``None``) or the file is empty.

    Raises:
        ValueError: If the file's top-level YAML value is not a mapping.
    """
    path = self.custom_config_path

    # No configuration file specified
    if path is None or path == "":
        logging.info("No configuration file has been specified.")
        return {}

    # Binary mode lets PyYAML detect the encoding (UTF-8, or UTF-16 via BOM)
    # instead of depending on the platform's locale encoding.
    with open(path, "rb") as file_:
        ccp = yaml.safe_load(file_)

    # ... but it is empty
    if ccp is None:
        logging.info("Configuration file is empty.")
        return {}

    # Fail early with a clear message instead of a later AttributeError on .items().
    if not isinstance(ccp, dict):
        raise ValueError(
            f"Configuration file '{path}' must contain a mapping of parameter names to values, "
            f"not {type(ccp).__name__}."
        )
    return ccp
@staticmethod
def validate(pdict, validate_serverrange=False):
@staticmethod
def validate(pdict, validate_serverrange=False):
    """Check that every entry of ``pdict`` is a known, well-formed parameter.

    Args:
        pdict: Mapping of parameter name to value.
        validate_serverrange: If true, also run each parameter's
            ``validate_serverrange`` check.

    Raises:
        ValueError: If a key is not a known parameter name. The per-parameter
            ``validate_*`` checks may raise their own errors for bad values.
    """
    # Collect the known names once instead of rescanning DEFAULT_PARAMETERS per key.
    valid_keys = frozenset(p.key for p in DEFAULT_PARAMETERS.values())
    for key, val in pdict.items():
        # Validate key
        if key not in valid_keys:
            raise ValueError(f"'{key}' is not a valid parameter name")

        # Validate value type and range
        param = DEFAULT_PARAMETERS[key]
        param.validate_dtype(val)
        param.validate_inrange(val)

        if validate_serverrange:
            param.validate_serverrange(val)