# aegis_sim.utilities.container

  1import pandas as pd
  2import pathlib
  3import logging
  4import json
  5import yaml
  6from typing import Union
  7import numpy as np
  8import psutil
  9
 10from aegis_sim.dataclasses.population import Population
 11from aegis_sim.constants import VALID_CAUSES_OF_DEATH
 12from aegis_sim.recording.ticker import Ticker
 13
 14
 15# TODO for analysis:
 16# TODO add 0 to survivorship, think about edge cases
 17# TODO clean up indices, columns and dtypes
 18# TODO be explicit about aggregation function
 19
 20
 21class Container:
 22    """
 23    Reads and reformats output files so they are available for internal and external use (prepare for export).
 24    """
 25
 26    def __init__(self, basepath):
 27        self.basepath = pathlib.Path(
 28            basepath
 29        ).absolute()  # If path to config file is /path/_.yml, then basepath is /path/_
 30        self.name = self.basepath.stem
 31        self.data = {}
 32        # self.set_paths()
 33        self.paths = None
 34        self.ticker = None
 35
 36    def set_paths(self):
 37        # TODO smarter way of listing paths; you are capturing te files with number keys e.g. '6': ... /te/6.csv; that's silly
 38        # TODO these are repeated elsewhere, e.g. path for ticker
 39        self.paths = {
 40            path.stem: path for path in self.basepath.glob("**/*") if path.is_file() and path.suffix == ".csv"
 41        }
 42        self.paths["log"] = self.basepath / "progress.log"
 43        self.paths["ticker"] = self.basepath / "ticker.txt"
 44        self.paths["simpleprogress"] = self.basepath / "simpleprogress.log"
 45        self.paths["output_summary"] = self.basepath / "output_summary.json"
 46        self.paths["input_summary"] = self.basepath / "input_summary.json"
 47        self.paths["envdriftmap"] = self.basepath / "envdriftmap.csv"
 48        self.paths["snapshots"] = {}
 49        for kind in ("demography", "phenotypes", "genotypes"):
 50            self.paths["snapshots"][kind] = sorted(
 51                (self.basepath / "snapshots" / kind).glob("*"),
 52                key=lambda path: int(path.stem),
 53            )
 54        self.paths["pickles"] = sorted(
 55            (self.basepath / "pickles").glob("*"),
 56            key=lambda path: int(path.stem),
 57        )
 58        # self.paths["te"] = sorted(
 59        #     (self.basepath / "te").glob("*"),
 60        #     key=lambda path: int(path.stem),
 61        # )
 62        self.paths["popsize_before_reproduction"] = self.basepath / "popsize_before_reproduction.csv"
 63        self.paths["popsize_after_reproduction"] = self.basepath / "popsize_after_reproduction.csv"
 64        self.paths["eggnum_after_reproduction"] = self.basepath / "eggnum_after_reproduction.csv"
 65
 66        if not self.paths["log"].is_file():
 67            logging.error(f"No AEGIS log found at path {self.paths['log']}.")
 68
 69    def get_paths(self):
 70        if self.paths is None:
 71            self.set_paths()
 72        return self.paths
 73
 74    def get_path(self, name):
 75        if self.paths is None:
 76            self.set_paths()
 77        return self.paths[name]
 78
 79    def get_record_structure(self):
 80        # TODO
 81        return
 82
 83    def report(self):
 84        """Report present and missing files"""
 85        # TODO
 86        return
 87
 88    def export(self):
 89        """Export all primary data from the container using general formats"""
 90        # TODO
 91        return
 92
 93    @staticmethod
 94    def stop_process(pid, kind_of_process):
 95        try:
 96            logging.info(f"Terminating {kind_of_process} process with PID {pid}...")
 97            process = psutil.Process(pid)
 98            process.terminate()  # or process.kill()
 99            process.wait()  # Optional: Wait for the process to be fully terminated
100            logging.info(f"Process with PID {pid} terminated successfully.")
101        except psutil.NoSuchProcess:
102            logging.warning(f"No process found with PID {pid}.")
103        except psutil.AccessDenied:
104            logging.warning(f"Access denied when trying to terminate the process with PID {pid}.")
105        except Exception as e:
106            logging.error(f"An error occurred: {e}")
107
108    def terminate(self):
109        pid = self.get_input_summary()["pid"]
110        assert pid is not None
111        self.stop_process(pid, "simulation")
112        tpid = self.get_input_summary()["ticker_pid"]
113        assert tpid is not None
114        self.stop_process(tpid, "ticker")
115
116    ############
117    # METADATA #
118    ############
119
120    def get_log(self, reload=True):
121        if ("log" not in self.data) or reload:
122            df = pd.read_csv(self.get_path("log"), sep="|")
123            df.columns = [x.strip() for x in df.columns]
124
125            def dhm_inverse(dhm):
126                nums = dhm.replace("`", ":").split(":")
127                return int(nums[0]) * 24 * 60 + int(nums[1]) * 60 + int(nums[2])
128
129            # TODO resolve deprecated function
130            try:
131                df[["ETA", "t1M", "runtime"]].map(dhm_inverse)
132            except:
133                df[["ETA", "t1M", "runtime"]].applymap(dhm_inverse)
134            self.data["log"] = df
135        return self.data["log"]
136
137    def get_simple_log(self):
138        try:
139            with open(self.get_path("simpleprogress"), "r") as file_:
140                text = file_.read()
141                step, steps_per_simulation = text.split("/")
142                return int(step), int(steps_per_simulation)
143        except:
144            logging.error(f"No simpleprogress.log found at {self.get_path('simpleprogress')}")
145
146    def get_ticker(self):
147        if self.ticker is None:
148            TICKER_RATE = self.get_final_config()["TICKER_RATE"]
149            self.ticker = Ticker(TICKER_RATE=TICKER_RATE, odir=self.get_path("ticker").parent)
150        return self.ticker
151
152    def get_config(self):
153        if "config" not in self.data:
154            path = self.basepath.parent / f"{self.basepath}.yml"
155            with open(path, "r") as file_:
156                custom_config = yaml.safe_load(file_)
157            # default_config = get_default_parameters()
158            if custom_config is None:
159                custom_config = {}
160            # self.data["config"] = {**default_config, **custom_config}
161            self.data["config"] = custom_config
162        return self.data["config"]
163
164    def get_final_config(self):
165        if "final_config" not in self.data:
166            path = self.basepath / "final_config.yml"
167            with open(path, "r") as file_:
168                final_config = yaml.safe_load(file_)
169            if final_config is None:
170                final_config = {}
171            self.data["final_config"] = final_config
172        return self.data["final_config"]
173
174    def get_generations_until_interval(self):
175        """Return Series of number of generations simulated up until interval i"""
176        # TODO beware that snapshots are not timed linearly; there is a bunch of snapshots at the end of the simulation
177        aar = self.get_average_age_at_reproduction()
178        aar.iloc[0] = np.inf  # No time has passed, so no generations yet
179        IR = self.get_final_config()["INTERVAL_RATE"]
180        aar = aar.pipe(lambda s: IR / s).cumsum()
181        return aar
182
183    def get_output_summary(self) -> Union[dict, None]:
184        path = self.get_path("output_summary")
185        if path.exists():
186            return self._read_json(path)
187        return {}
188
189    def get_input_summary(self):
190        return self._read_json(self.get_path("input_summary"))
191
192    def get_envidriftmap(self):
193        path = self.get_path("envdriftmap")
194        if path.exists():
195            return pd.read_csv(self.get_path("envdriftmap"), header=None)
196        return None
197
198    def get_phenomap(self):
199        path = self.get_path("phenomap")
200        if path.exists():
201            return pd.read_csv(self.get_path("phenomap"))
202        return None
203
204    ##########
205    # TABLES #
206    ##########
207
208    def get_birth_table_observed_interval(self, normalize=False):
209        """
210        Observed data.
211        Number of births (int) per parental age during an interval of length INTERVAL_RATE.
212        columns.name == parental_age (int)
213        index.name == interval (int)
214        """
215        table = self._read_df("age_at_birth")
216        if normalize:
217            table = table.div(table.sum(1), axis=0)
218        table.index.names = ["interval"]
219        table.columns.names = ["parental_age"]
220        table.columns = table.columns.astype(int)
221        return table
222
223    def get_life_table_observed_interval(self, normalize=False):
224        """
225        Observed data.
226        Number of individuals (int) per age class observed during an interval of length INTERVAL_RATE.
227        columns.name == age_class (int)
228        index.name == interval (int)
229        """
230        table = self._read_df("additive_age_structure")
231        table.index.names = ["interval"]
232        table.columns.names = ["age_class"]
233        table.columns = table.columns.astype(int)
234        # NOTE normalize by sum
235        if normalize:
236            table = table.div(table.sum(1), axis=0)
237        return table
238
239    def get_life_table_observed_snapshot(self, record_index: int, normalize=False):
240        """
241        Observed data. Series.
242        Number of individuals (int) per age class observed at some simulation step captured by the record of index record_index.
243        name == count
244        index.name == age_class
245        """
246        AGE_LIMIT = self.get_final_config()["AGE_LIMIT"]
247        table = (
248            self.get_demography_observed_snapshot(record_index)
249            .ages.value_counts()
250            .reindex(range(AGE_LIMIT), fill_value=0)
251        )
252        table.index.names = ["age_class"]
253        return table
254
255    def get_death_table_observed_interval(self, normalize=False):
256        """
257        Observed data. Has a MultiIndex.
258        Number of deaths (int) per age class observed during an interval of length INTERVAL_RATE.
259        columns.name == age_class (int)
260        index.names == ["interval", "cause_of_death"] (int, str)
261        """
262        # TODO think about position of axes
263        table = (
264            pd.concat({causeofdeath: self._read_df(f"age_at_{causeofdeath}") for causeofdeath in VALID_CAUSES_OF_DEATH})
265            .swaplevel()
266            .sort_index(level=0)
267        )
268        table.index.names = ["interval", "cause_of_death"]
269        table.columns.names = ["age_class"]
270        table.columns = table.columns.astype(int)
271        return table
272
273    #######################
274    # TABLES : derivative #
275    #######################
276
277    def get_surv_observed_interval(self):
278        # TODO this is not accurate; this assumes that the population is in an equilibrium, or it only works if the life table is sampling across a long period
279        lt = self.get_life_table_observed_interval()
280        lt = lt.pct_change(axis=1).shift(-1, axis=1).add(1).replace(np.inf, 1)
281        return lt
282
283    def get_fert_observed_interval(self):
284        lt = self.get_life_table_observed_interval()
285        bt = self.get_birth_table_observed_interval()
286        return bt / lt
287
288    ##########
289    # BASICS #
290    ##########
291
292    # TODO add better column and index names
293
294    def get_genotypes_intrinsic_snapshot(self, record_index):
295        """
296        columns .. bit index
297        index .. individual index
298        value .. True or False
299        """
300        # TODO let index denote the step at which the snapshot was taken
301        return self._read_snapshot("genotypes", record_index=record_index)
302
303    def get_phenotype_intrinsic_snapshot(self, trait, record_index):
304        """
305        columns .. phenotypic trait index
306        index .. individual index
307        value .. phenotypic trait value
308        """
309        # TODO organize by trait
310        # TODO let index denote the step at which the snapshot was taken
311        df = self._read_snapshot("phenotypes", record_index=record_index)
312        # df.columns = df.columns.str.split("_")
313        return df
314
315    def get_demography_observed_snapshot(self, record_index):
316        """
317        columns .. ages, births, birthdays, generations, sizes, sexes
318        index .. individual index
319        """
320        # TODO let index denote the step at which the snapshot was taken
321        return self._read_snapshot("demography", record_index=record_index)
322
323    def get_genotypes_intrinsic_interval(self, reload=True):
324        """
325        columns .. bit index
326        index .. record index
327        value .. mean bit value
328        """
329        # TODO check that they exist
330        df = pd.read_csv(self.get_path("genotypes"), header=[0, 1], index_col=None)
331        df.index = df.index.astype(int)
332        df.columns = df.columns.set_levels([df.columns.levels[0].astype(int), df.columns.levels[1].astype(int)])
333        df.index.names = ["interval"]
334        df.columns.names = ["bit_index", "ploidy"]
335        return df
336
337    def get_phenotype_intrinsic_interval(self, trait, reload=True):
338        """
339        columns .. age
340        index .. record index
341        value .. median phenotypic trait value
342        """
343        # TODO check that they exist
344        df = pd.read_csv(self.get_path("phenotypes"), header=[0, 1])
345        df.index.names = ["interval"]
346        df.index = df.index.astype(int)
347        df.columns.names = ["trait", "age_class"]
348        # TODO age_class is str
349        return df.xs(trait, axis=1)
350
351    def get_survival_analysis_TE_observed_interval(self, record_index):
352        """
353        columns .. T, E
354        index .. individual
355        value .. age at event, event (1 .. died, 0 .. alive)
356        """
357        # TODO error with T and E in the record; they are being appended on top
358        assert record_index < len(self.get_path("te")), "Index out of range"
359        data = pd.read_csv(self.get_path("te")[record_index], header=0)
360        data.index.names = ["individual"]
361        return data
362
363    def get_population_size_before_reproduction(self):
364        data = pd.read_csv(self.get_path("popsize_before_reproduction"), header=None)
365        data.index.names = ["steps"]
366        data.columns = ["popsize"]
367        return data
368
369    def get_population_size_after_reproduction(self):
370        data = pd.read_csv(self.get_path("popsize_after_reproduction"), header=None)
371        data.index.names = ["steps"]
372        data.columns = ["popsize"]
373        return data
374
375    def get_egg_number_after_reproduction(self):
376        data = pd.read_csv(self.get_path("eggnum_after_reproduction"), header=None)
377        data.index.names = ["steps"]
378        data.columns = ["number"]
379        return data
380
381    def get_resource_amount_before_scavenging(self):
382        data = pd.read_csv(self.get_path("resources_before_scavenging"), header=None)
383        data.index.names = ["steps"]
384        data.columns = ["resources"]
385        return data
386
387    def get_resource_amount_after_scavenging(self):
388        data = pd.read_csv(self.get_path("resources_after_scavenging"), header=None)
389        data.index.names = ["steps"]
390        data.columns = ["resources"]
391        return data
392
393    ###############
394    # DERIVATIVES #
395    ###############
396
397    def get_lifetime_reproduction(self):
398        survivorship = self.get_surv_observed_interval().cumprod(1)
399        fertility = self.get_fert_observed_interval()
400        return (survivorship * fertility).sum(axis=1)
401
402    def get_average_age_at_reproduction(self):
403        bt = self.get_birth_table_observed_interval()
404        n_offspring = bt.sum(1)
405        average_age_at_reproduction = (bt * bt.columns).sum(1) / n_offspring
406        return average_age_at_reproduction
407
408    #############
409    # UTILITIES #
410    #############
411
412    def _file_exists(self, stem):
413        if self.paths is None:
414            self.set_paths()
415        return stem in self.paths
416
417    def has_ticker_stopped(self):
418        return self.get_ticker().has_stopped()
419
420    def _read_df(self, stem, reload=True):
421        file_read = stem in self.data
422        # TODO Read also files that are not .csv
423
424        if not self._file_exists(stem):
425            logging.error(f"File {self.get_path(stem)} des not exist.")
426        elif (not file_read) or reload:
427            self.data[stem] = pd.read_csv(self.get_path(stem), header=0)
428
429        return self.data.get(stem, pd.DataFrame())
430
431    @staticmethod
432    def _read_json(path):
433        if not path.exists():
434            logging.warning(f"'{path}' does not exist.")
435            return None
436        with open(path, "r") as file_:
437            return json.load(file_)
438
439    def _read_snapshot(self, record_type, record_index):
440        assert record_type in self.get_path("snapshots"), f"No records of '{record_type}' can be found in snapshots"
441        assert record_index < len(self.get_path("snapshots")[record_type]), "Index out of range"
442        return pd.read_feather(self.get_path("snapshots")[record_type][record_index])
443
444    def _read_pickle(self, record_index):
445        assert record_index < len(self.get_path("pickles")), "Index out of range"
446        return Population.load_pickle_from(self.get_path("pickles")[record_index])
class Container:
 22class Container:
 23    """
 24    Reads and reformats output files so they are available for internal and external use (prepare for export).
 25    """
 26
 27    def __init__(self, basepath):
 28        self.basepath = pathlib.Path(
 29            basepath
 30        ).absolute()  # If path to config file is /path/_.yml, then basepath is /path/_
 31        self.name = self.basepath.stem
 32        self.data = {}
 33        # self.set_paths()
 34        self.paths = None
 35        self.ticker = None
 36
 37    def set_paths(self):
 38        # TODO smarter way of listing paths; you are capturing te files with number keys e.g. '6': ... /te/6.csv; that's silly
 39        # TODO these are repeated elsewhere, e.g. path for ticker
 40        self.paths = {
 41            path.stem: path for path in self.basepath.glob("**/*") if path.is_file() and path.suffix == ".csv"
 42        }
 43        self.paths["log"] = self.basepath / "progress.log"
 44        self.paths["ticker"] = self.basepath / "ticker.txt"
 45        self.paths["simpleprogress"] = self.basepath / "simpleprogress.log"
 46        self.paths["output_summary"] = self.basepath / "output_summary.json"
 47        self.paths["input_summary"] = self.basepath / "input_summary.json"
 48        self.paths["envdriftmap"] = self.basepath / "envdriftmap.csv"
 49        self.paths["snapshots"] = {}
 50        for kind in ("demography", "phenotypes", "genotypes"):
 51            self.paths["snapshots"][kind] = sorted(
 52                (self.basepath / "snapshots" / kind).glob("*"),
 53                key=lambda path: int(path.stem),
 54            )
 55        self.paths["pickles"] = sorted(
 56            (self.basepath / "pickles").glob("*"),
 57            key=lambda path: int(path.stem),
 58        )
 59        # self.paths["te"] = sorted(
 60        #     (self.basepath / "te").glob("*"),
 61        #     key=lambda path: int(path.stem),
 62        # )
 63        self.paths["popsize_before_reproduction"] = self.basepath / "popsize_before_reproduction.csv"
 64        self.paths["popsize_after_reproduction"] = self.basepath / "popsize_after_reproduction.csv"
 65        self.paths["eggnum_after_reproduction"] = self.basepath / "eggnum_after_reproduction.csv"
 66
 67        if not self.paths["log"].is_file():
 68            logging.error(f"No AEGIS log found at path {self.paths['log']}.")
 69
 70    def get_paths(self):
 71        if self.paths is None:
 72            self.set_paths()
 73        return self.paths
 74
 75    def get_path(self, name):
 76        if self.paths is None:
 77            self.set_paths()
 78        return self.paths[name]
 79
 80    def get_record_structure(self):
 81        # TODO
 82        return
 83
 84    def report(self):
 85        """Report present and missing files"""
 86        # TODO
 87        return
 88
 89    def export(self):
 90        """Export all primary data from the container using general formats"""
 91        # TODO
 92        return
 93
 94    @staticmethod
 95    def stop_process(pid, kind_of_process):
 96        try:
 97            logging.info(f"Terminating {kind_of_process} process with PID {pid}...")
 98            process = psutil.Process(pid)
 99            process.terminate()  # or process.kill()
100            process.wait()  # Optional: Wait for the process to be fully terminated
101            logging.info(f"Process with PID {pid} terminated successfully.")
102        except psutil.NoSuchProcess:
103            logging.warning(f"No process found with PID {pid}.")
104        except psutil.AccessDenied:
105            logging.warning(f"Access denied when trying to terminate the process with PID {pid}.")
106        except Exception as e:
107            logging.error(f"An error occurred: {e}")
108
109    def terminate(self):
110        pid = self.get_input_summary()["pid"]
111        assert pid is not None
112        self.stop_process(pid, "simulation")
113        tpid = self.get_input_summary()["ticker_pid"]
114        assert tpid is not None
115        self.stop_process(tpid, "ticker")
116
117    ############
118    # METADATA #
119    ############
120
121    def get_log(self, reload=True):
122        if ("log" not in self.data) or reload:
123            df = pd.read_csv(self.get_path("log"), sep="|")
124            df.columns = [x.strip() for x in df.columns]
125
126            def dhm_inverse(dhm):
127                nums = dhm.replace("`", ":").split(":")
128                return int(nums[0]) * 24 * 60 + int(nums[1]) * 60 + int(nums[2])
129
130            # TODO resolve deprecated function
131            try:
132                df[["ETA", "t1M", "runtime"]].map(dhm_inverse)
133            except:
134                df[["ETA", "t1M", "runtime"]].applymap(dhm_inverse)
135            self.data["log"] = df
136        return self.data["log"]
137
138    def get_simple_log(self):
139        try:
140            with open(self.get_path("simpleprogress"), "r") as file_:
141                text = file_.read()
142                step, steps_per_simulation = text.split("/")
143                return int(step), int(steps_per_simulation)
144        except:
145            logging.error(f"No simpleprogress.log found at {self.get_path('simpleprogress')}")
146
147    def get_ticker(self):
148        if self.ticker is None:
149            TICKER_RATE = self.get_final_config()["TICKER_RATE"]
150            self.ticker = Ticker(TICKER_RATE=TICKER_RATE, odir=self.get_path("ticker").parent)
151        return self.ticker
152
153    def get_config(self):
154        if "config" not in self.data:
155            path = self.basepath.parent / f"{self.basepath}.yml"
156            with open(path, "r") as file_:
157                custom_config = yaml.safe_load(file_)
158            # default_config = get_default_parameters()
159            if custom_config is None:
160                custom_config = {}
161            # self.data["config"] = {**default_config, **custom_config}
162            self.data["config"] = custom_config
163        return self.data["config"]
164
165    def get_final_config(self):
166        if "final_config" not in self.data:
167            path = self.basepath / "final_config.yml"
168            with open(path, "r") as file_:
169                final_config = yaml.safe_load(file_)
170            if final_config is None:
171                final_config = {}
172            self.data["final_config"] = final_config
173        return self.data["final_config"]
174
175    def get_generations_until_interval(self):
176        """Return Series of number of generations simulated up until interval i"""
177        # TODO beware that snapshots are not timed linearly; there is a bunch of snapshots at the end of the simulation
178        aar = self.get_average_age_at_reproduction()
179        aar.iloc[0] = np.inf  # No time has passed, so no generations yet
180        IR = self.get_final_config()["INTERVAL_RATE"]
181        aar = aar.pipe(lambda s: IR / s).cumsum()
182        return aar
183
184    def get_output_summary(self) -> Union[dict, None]:
185        path = self.get_path("output_summary")
186        if path.exists():
187            return self._read_json(path)
188        return {}
189
190    def get_input_summary(self):
191        return self._read_json(self.get_path("input_summary"))
192
193    def get_envidriftmap(self):
194        path = self.get_path("envdriftmap")
195        if path.exists():
196            return pd.read_csv(self.get_path("envdriftmap"), header=None)
197        return None
198
199    def get_phenomap(self):
200        path = self.get_path("phenomap")
201        if path.exists():
202            return pd.read_csv(self.get_path("phenomap"))
203        return None
204
205    ##########
206    # TABLES #
207    ##########
208
209    def get_birth_table_observed_interval(self, normalize=False):
210        """
211        Observed data.
212        Number of births (int) per parental age during an interval of length INTERVAL_RATE.
213        columns.name == parental_age (int)
214        index.name == interval (int)
215        """
216        table = self._read_df("age_at_birth")
217        if normalize:
218            table = table.div(table.sum(1), axis=0)
219        table.index.names = ["interval"]
220        table.columns.names = ["parental_age"]
221        table.columns = table.columns.astype(int)
222        return table
223
224    def get_life_table_observed_interval(self, normalize=False):
225        """
226        Observed data.
227        Number of individuals (int) per age class observed during an interval of length INTERVAL_RATE.
228        columns.name == age_class (int)
229        index.name == interval (int)
230        """
231        table = self._read_df("additive_age_structure")
232        table.index.names = ["interval"]
233        table.columns.names = ["age_class"]
234        table.columns = table.columns.astype(int)
235        # NOTE normalize by sum
236        if normalize:
237            table = table.div(table.sum(1), axis=0)
238        return table
239
240    def get_life_table_observed_snapshot(self, record_index: int, normalize=False):
241        """
242        Observed data. Series.
243        Number of individuals (int) per age class observed at some simulation step captured by the record of index record_index.
244        name == count
245        index.name == age_class
246        """
247        AGE_LIMIT = self.get_final_config()["AGE_LIMIT"]
248        table = (
249            self.get_demography_observed_snapshot(record_index)
250            .ages.value_counts()
251            .reindex(range(AGE_LIMIT), fill_value=0)
252        )
253        table.index.names = ["age_class"]
254        return table
255
256    def get_death_table_observed_interval(self, normalize=False):
257        """
258        Observed data. Has a MultiIndex.
259        Number of deaths (int) per age class observed during an interval of length INTERVAL_RATE.
260        columns.name == age_class (int)
261        index.names == ["interval", "cause_of_death"] (int, str)
262        """
263        # TODO think about position of axes
264        table = (
265            pd.concat({causeofdeath: self._read_df(f"age_at_{causeofdeath}") for causeofdeath in VALID_CAUSES_OF_DEATH})
266            .swaplevel()
267            .sort_index(level=0)
268        )
269        table.index.names = ["interval", "cause_of_death"]
270        table.columns.names = ["age_class"]
271        table.columns = table.columns.astype(int)
272        return table
273
274    #######################
275    # TABLES : derivative #
276    #######################
277
278    def get_surv_observed_interval(self):
279        # TODO this is not accurate; this assumes that the population is in an equilibrium, or it only works if the life table is sampling across a long period
280        lt = self.get_life_table_observed_interval()
281        lt = lt.pct_change(axis=1).shift(-1, axis=1).add(1).replace(np.inf, 1)
282        return lt
283
284    def get_fert_observed_interval(self):
285        lt = self.get_life_table_observed_interval()
286        bt = self.get_birth_table_observed_interval()
287        return bt / lt
288
289    ##########
290    # BASICS #
291    ##########
292
293    # TODO add better column and index names
294
295    def get_genotypes_intrinsic_snapshot(self, record_index):
296        """
297        columns .. bit index
298        index .. individual index
299        value .. True or False
300        """
301        # TODO let index denote the step at which the snapshot was taken
302        return self._read_snapshot("genotypes", record_index=record_index)
303
304    def get_phenotype_intrinsic_snapshot(self, trait, record_index):
305        """
306        columns .. phenotypic trait index
307        index .. individual index
308        value .. phenotypic trait value
309        """
310        # TODO organize by trait
311        # TODO let index denote the step at which the snapshot was taken
312        df = self._read_snapshot("phenotypes", record_index=record_index)
313        # df.columns = df.columns.str.split("_")
314        return df
315
316    def get_demography_observed_snapshot(self, record_index):
317        """
318        columns .. ages, births, birthdays, generations, sizes, sexes
319        index .. individual index
320        """
321        # TODO let index denote the step at which the snapshot was taken
322        return self._read_snapshot("demography", record_index=record_index)
323
324    def get_genotypes_intrinsic_interval(self, reload=True):
325        """
326        columns .. bit index
327        index .. record index
328        value .. mean bit value
329        """
330        # TODO check that they exist
331        df = pd.read_csv(self.get_path("genotypes"), header=[0, 1], index_col=None)
332        df.index = df.index.astype(int)
333        df.columns = df.columns.set_levels([df.columns.levels[0].astype(int), df.columns.levels[1].astype(int)])
334        df.index.names = ["interval"]
335        df.columns.names = ["bit_index", "ploidy"]
336        return df
337
338    def get_phenotype_intrinsic_interval(self, trait, reload=True):
339        """
340        columns .. age
341        index .. record index
342        value .. median phenotypic trait value
343        """
344        # TODO check that they exist
345        df = pd.read_csv(self.get_path("phenotypes"), header=[0, 1])
346        df.index.names = ["interval"]
347        df.index = df.index.astype(int)
348        df.columns.names = ["trait", "age_class"]
349        # TODO age_class is str
350        return df.xs(trait, axis=1)
351
352    def get_survival_analysis_TE_observed_interval(self, record_index):
353        """
354        columns .. T, E
355        index .. individual
356        value .. age at event, event (1 .. died, 0 .. alive)
357        """
358        # TODO error with T and E in the record; they are being appended on top
359        assert record_index < len(self.get_path("te")), "Index out of range"
360        data = pd.read_csv(self.get_path("te")[record_index], header=0)
361        data.index.names = ["individual"]
362        return data
363
364    def get_population_size_before_reproduction(self):
365        data = pd.read_csv(self.get_path("popsize_before_reproduction"), header=None)
366        data.index.names = ["steps"]
367        data.columns = ["popsize"]
368        return data
369
370    def get_population_size_after_reproduction(self):
371        data = pd.read_csv(self.get_path("popsize_after_reproduction"), header=None)
372        data.index.names = ["steps"]
373        data.columns = ["popsize"]
374        return data
375
376    def get_egg_number_after_reproduction(self):
377        data = pd.read_csv(self.get_path("eggnum_after_reproduction"), header=None)
378        data.index.names = ["steps"]
379        data.columns = ["number"]
380        return data
381
382    def get_resource_amount_before_scavenging(self):
383        data = pd.read_csv(self.get_path("resources_before_scavenging"), header=None)
384        data.index.names = ["steps"]
385        data.columns = ["resources"]
386        return data
387
388    def get_resource_amount_after_scavenging(self):
389        data = pd.read_csv(self.get_path("resources_after_scavenging"), header=None)
390        data.index.names = ["steps"]
391        data.columns = ["resources"]
392        return data
393
394    ###############
395    # DERIVATIVES #
396    ###############
397
398    def get_lifetime_reproduction(self):
399        survivorship = self.get_surv_observed_interval().cumprod(1)
400        fertility = self.get_fert_observed_interval()
401        return (survivorship * fertility).sum(axis=1)
402
403    def get_average_age_at_reproduction(self):
404        bt = self.get_birth_table_observed_interval()
405        n_offspring = bt.sum(1)
406        average_age_at_reproduction = (bt * bt.columns).sum(1) / n_offspring
407        return average_age_at_reproduction
408
409    #############
410    # UTILITIES #
411    #############
412
413    def _file_exists(self, stem):
414        if self.paths is None:
415            self.set_paths()
416        return stem in self.paths
417
418    def has_ticker_stopped(self):
419        return self.get_ticker().has_stopped()
420
421    def _read_df(self, stem, reload=True):
422        file_read = stem in self.data
423        # TODO Read also files that are not .csv
424
425        if not self._file_exists(stem):
426            logging.error(f"File {self.get_path(stem)} des not exist.")
427        elif (not file_read) or reload:
428            self.data[stem] = pd.read_csv(self.get_path(stem), header=0)
429
430        return self.data.get(stem, pd.DataFrame())
431
432    @staticmethod
433    def _read_json(path):
434        if not path.exists():
435            logging.warning(f"'{path}' does not exist.")
436            return None
437        with open(path, "r") as file_:
438            return json.load(file_)
439
440    def _read_snapshot(self, record_type, record_index):
441        assert record_type in self.get_path("snapshots"), f"No records of '{record_type}' can be found in snapshots"
442        assert record_index < len(self.get_path("snapshots")[record_type]), "Index out of range"
443        return pd.read_feather(self.get_path("snapshots")[record_type][record_index])
444
445    def _read_pickle(self, record_index):
446        assert record_index < len(self.get_path("pickles")), "Index out of range"
447        return Population.load_pickle_from(self.get_path("pickles")[record_index])

Reads and reformats output files so they are available for internal and external use (prepare for export).

Container(basepath)
27    def __init__(self, basepath):
28        self.basepath = pathlib.Path(
29            basepath
30        ).absolute()  # If path to config file is /path/_.yml, then basepath is /path/_
31        self.name = self.basepath.stem
32        self.data = {}
33        # self.set_paths()
34        self.paths = None
35        self.ticker = None
basepath
name
data
paths
ticker
def set_paths(self):
37    def set_paths(self):
38        # TODO smarter way of listing paths; you are capturing te files with number keys e.g. '6': ... /te/6.csv; that's silly
39        # TODO these are repeated elsewhere, e.g. path for ticker
40        self.paths = {
41            path.stem: path for path in self.basepath.glob("**/*") if path.is_file() and path.suffix == ".csv"
42        }
43        self.paths["log"] = self.basepath / "progress.log"
44        self.paths["ticker"] = self.basepath / "ticker.txt"
45        self.paths["simpleprogress"] = self.basepath / "simpleprogress.log"
46        self.paths["output_summary"] = self.basepath / "output_summary.json"
47        self.paths["input_summary"] = self.basepath / "input_summary.json"
48        self.paths["envdriftmap"] = self.basepath / "envdriftmap.csv"
49        self.paths["snapshots"] = {}
50        for kind in ("demography", "phenotypes", "genotypes"):
51            self.paths["snapshots"][kind] = sorted(
52                (self.basepath / "snapshots" / kind).glob("*"),
53                key=lambda path: int(path.stem),
54            )
55        self.paths["pickles"] = sorted(
56            (self.basepath / "pickles").glob("*"),
57            key=lambda path: int(path.stem),
58        )
59        # self.paths["te"] = sorted(
60        #     (self.basepath / "te").glob("*"),
61        #     key=lambda path: int(path.stem),
62        # )
63        self.paths["popsize_before_reproduction"] = self.basepath / "popsize_before_reproduction.csv"
64        self.paths["popsize_after_reproduction"] = self.basepath / "popsize_after_reproduction.csv"
65        self.paths["eggnum_after_reproduction"] = self.basepath / "eggnum_after_reproduction.csv"
66
67        if not self.paths["log"].is_file():
68            logging.error(f"No AEGIS log found at path {self.paths['log']}.")
def get_paths(self):
70    def get_paths(self):
71        if self.paths is None:
72            self.set_paths()
73        return self.paths
def get_path(self, name):
75    def get_path(self, name):
76        if self.paths is None:
77            self.set_paths()
78        return self.paths[name]
def get_record_structure(self):
80    def get_record_structure(self):
81        # TODO
82        return
def report(self):
84    def report(self):
85        """Report present and missing files"""
86        # TODO
87        return

Report present and missing files

def export(self):
89    def export(self):
90        """Export all primary data from the container using general formats"""
91        # TODO
92        return

Export all primary data from the container using general formats

@staticmethod
def stop_process(pid, kind_of_process):
 94    @staticmethod
 95    def stop_process(pid, kind_of_process):
 96        try:
 97            logging.info(f"Terminating {kind_of_process} process with PID {pid}...")
 98            process = psutil.Process(pid)
 99            process.terminate()  # or process.kill()
100            process.wait()  # Optional: Wait for the process to be fully terminated
101            logging.info(f"Process with PID {pid} terminated successfully.")
102        except psutil.NoSuchProcess:
103            logging.warning(f"No process found with PID {pid}.")
104        except psutil.AccessDenied:
105            logging.warning(f"Access denied when trying to terminate the process with PID {pid}.")
106        except Exception as e:
107            logging.error(f"An error occurred: {e}")
def terminate(self):
109    def terminate(self):
110        pid = self.get_input_summary()["pid"]
111        assert pid is not None
112        self.stop_process(pid, "simulation")
113        tpid = self.get_input_summary()["ticker_pid"]
114        assert tpid is not None
115        self.stop_process(tpid, "ticker")
def get_log(self, reload=True):
121    def get_log(self, reload=True):
122        if ("log" not in self.data) or reload:
123            df = pd.read_csv(self.get_path("log"), sep="|")
124            df.columns = [x.strip() for x in df.columns]
125
126            def dhm_inverse(dhm):
127                nums = dhm.replace("`", ":").split(":")
128                return int(nums[0]) * 24 * 60 + int(nums[1]) * 60 + int(nums[2])
129
130            # TODO resolve deprecated function
131            try:
132                df[["ETA", "t1M", "runtime"]].map(dhm_inverse)
133            except:
134                df[["ETA", "t1M", "runtime"]].applymap(dhm_inverse)
135            self.data["log"] = df
136        return self.data["log"]
def get_simple_log(self):
138    def get_simple_log(self):
139        try:
140            with open(self.get_path("simpleprogress"), "r") as file_:
141                text = file_.read()
142                step, steps_per_simulation = text.split("/")
143                return int(step), int(steps_per_simulation)
144        except:
145            logging.error(f"No simpleprogress.log found at {self.get_path('simpleprogress')}")
def get_ticker(self):
147    def get_ticker(self):
148        if self.ticker is None:
149            TICKER_RATE = self.get_final_config()["TICKER_RATE"]
150            self.ticker = Ticker(TICKER_RATE=TICKER_RATE, odir=self.get_path("ticker").parent)
151        return self.ticker
def get_config(self):
153    def get_config(self):
154        if "config" not in self.data:
155            path = self.basepath.parent / f"{self.basepath}.yml"
156            with open(path, "r") as file_:
157                custom_config = yaml.safe_load(file_)
158            # default_config = get_default_parameters()
159            if custom_config is None:
160                custom_config = {}
161            # self.data["config"] = {**default_config, **custom_config}
162            self.data["config"] = custom_config
163        return self.data["config"]
def get_final_config(self):
165    def get_final_config(self):
166        if "final_config" not in self.data:
167            path = self.basepath / "final_config.yml"
168            with open(path, "r") as file_:
169                final_config = yaml.safe_load(file_)
170            if final_config is None:
171                final_config = {}
172            self.data["final_config"] = final_config
173        return self.data["final_config"]
def get_generations_until_interval(self):
175    def get_generations_until_interval(self):
176        """Return Series of number of generations simulated up until interval i"""
177        # TODO beware that snapshots are not timed linearly; there is a bunch of snapshots at the end of the simulation
178        aar = self.get_average_age_at_reproduction()
179        aar.iloc[0] = np.inf  # No time has passed, so no generations yet
180        IR = self.get_final_config()["INTERVAL_RATE"]
181        aar = aar.pipe(lambda s: IR / s).cumsum()
182        return aar

Return Series of number of generations simulated up until interval i

def get_output_summary(self) -> Union[dict, None]:
184    def get_output_summary(self) -> Union[dict, None]:
185        path = self.get_path("output_summary")
186        if path.exists():
187            return self._read_json(path)
188        return {}
def get_input_summary(self):
190    def get_input_summary(self):
191        return self._read_json(self.get_path("input_summary"))
def get_envidriftmap(self):
193    def get_envidriftmap(self):
194        path = self.get_path("envdriftmap")
195        if path.exists():
196            return pd.read_csv(self.get_path("envdriftmap"), header=None)
197        return None
def get_phenomap(self):
199    def get_phenomap(self):
200        path = self.get_path("phenomap")
201        if path.exists():
202            return pd.read_csv(self.get_path("phenomap"))
203        return None
def get_birth_table_observed_interval(self, normalize=False):
209    def get_birth_table_observed_interval(self, normalize=False):
210        """
211        Observed data.
212        Number of births (int) per parental age during an interval of length INTERVAL_RATE.
213        columns.name == parental_age (int)
214        index.name == interval (int)
215        """
216        table = self._read_df("age_at_birth")
217        if normalize:
218            table = table.div(table.sum(1), axis=0)
219        table.index.names = ["interval"]
220        table.columns.names = ["parental_age"]
221        table.columns = table.columns.astype(int)
222        return table

Observed data. Number of births (int) per parental age during an interval of length INTERVAL_RATE. columns.name == parental_age (int) index.name == interval (int)

def get_life_table_observed_interval(self, normalize=False):
224    def get_life_table_observed_interval(self, normalize=False):
225        """
226        Observed data.
227        Number of individuals (int) per age class observed during an interval of length INTERVAL_RATE.
228        columns.name == age_class (int)
229        index.name == interval (int)
230        """
231        table = self._read_df("additive_age_structure")
232        table.index.names = ["interval"]
233        table.columns.names = ["age_class"]
234        table.columns = table.columns.astype(int)
235        # NOTE normalize by sum
236        if normalize:
237            table = table.div(table.sum(1), axis=0)
238        return table

Observed data. Number of individuals (int) per age class observed during an interval of length INTERVAL_RATE. columns.name == age_class (int) index.name == interval (int)

def get_life_table_observed_snapshot(self, record_index: int, normalize=False):
240    def get_life_table_observed_snapshot(self, record_index: int, normalize=False):
241        """
242        Observed data. Series.
243        Number of individuals (int) per age class observed at some simulation step captured by the record of index record_index.
244        name == count
245        index.name == age_class
246        """
247        AGE_LIMIT = self.get_final_config()["AGE_LIMIT"]
248        table = (
249            self.get_demography_observed_snapshot(record_index)
250            .ages.value_counts()
251            .reindex(range(AGE_LIMIT), fill_value=0)
252        )
253        table.index.names = ["age_class"]
254        return table

Observed data. Series. Number of individuals (int) per age class observed at some simulation step captured by the record of index record_index. name == count index.name == age_class

def get_death_table_observed_interval(self, normalize=False):
256    def get_death_table_observed_interval(self, normalize=False):
257        """
258        Observed data. Has a MultiIndex.
259        Number of deaths (int) per age class observed during an interval of length INTERVAL_RATE.
260        columns.name == age_class (int)
261        index.names == ["interval", "cause_of_death"] (int, str)
262        """
263        # TODO think about position of axes
264        table = (
265            pd.concat({causeofdeath: self._read_df(f"age_at_{causeofdeath}") for causeofdeath in VALID_CAUSES_OF_DEATH})
266            .swaplevel()
267            .sort_index(level=0)
268        )
269        table.index.names = ["interval", "cause_of_death"]
270        table.columns.names = ["age_class"]
271        table.columns = table.columns.astype(int)
272        return table

Observed data. Has a MultiIndex. Number of deaths (int) per age class observed during an interval of length INTERVAL_RATE. columns.name == age_class (int) index.names == ["interval", "cause_of_death"] (int, str)

def get_surv_observed_interval(self):
278    def get_surv_observed_interval(self):
279        # TODO this is not accurate; this assumes that the population is in an equilibrium, or it only works if the life table is sampling across a long period
280        lt = self.get_life_table_observed_interval()
281        lt = lt.pct_change(axis=1).shift(-1, axis=1).add(1).replace(np.inf, 1)
282        return lt
def get_fert_observed_interval(self):
284    def get_fert_observed_interval(self):
285        lt = self.get_life_table_observed_interval()
286        bt = self.get_birth_table_observed_interval()
287        return bt / lt
def get_genotypes_intrinsic_snapshot(self, record_index):
295    def get_genotypes_intrinsic_snapshot(self, record_index):
296        """
297        columns .. bit index
298        index .. individual index
299        value .. True or False
300        """
301        # TODO let index denote the step at which the snapshot was taken
302        return self._read_snapshot("genotypes", record_index=record_index)

columns .. bit index index .. individual index value .. True or False

def get_phenotype_intrinsic_snapshot(self, trait, record_index):
304    def get_phenotype_intrinsic_snapshot(self, trait, record_index):
305        """
306        columns .. phenotypic trait index
307        index .. individual index
308        value .. phenotypic trait value
309        """
310        # TODO organize by trait
311        # TODO let index denote the step at which the snapshot was taken
312        df = self._read_snapshot("phenotypes", record_index=record_index)
313        # df.columns = df.columns.str.split("_")
314        return df

columns .. phenotypic trait index index .. individual index value .. phenotypic trait value

def get_demography_observed_snapshot(self, record_index):
316    def get_demography_observed_snapshot(self, record_index):
317        """
318        columns .. ages, births, birthdays, generations, sizes, sexes
319        index .. individual index
320        """
321        # TODO let index denote the step at which the snapshot was taken
322        return self._read_snapshot("demography", record_index=record_index)

columns .. ages, births, birthdays, generations, sizes, sexes index .. individual index

def get_genotypes_intrinsic_interval(self, reload=True):
324    def get_genotypes_intrinsic_interval(self, reload=True):
325        """
326        columns .. bit index
327        index .. record index
328        value .. mean bit value
329        """
330        # TODO check that they exist
331        df = pd.read_csv(self.get_path("genotypes"), header=[0, 1], index_col=None)
332        df.index = df.index.astype(int)
333        df.columns = df.columns.set_levels([df.columns.levels[0].astype(int), df.columns.levels[1].astype(int)])
334        df.index.names = ["interval"]
335        df.columns.names = ["bit_index", "ploidy"]
336        return df

columns .. bit index index .. record index value .. mean bit value

def get_phenotype_intrinsic_interval(self, trait, reload=True):
338    def get_phenotype_intrinsic_interval(self, trait, reload=True):
339        """
340        columns .. age
341        index .. record index
342        value .. median phenotypic trait value
343        """
344        # TODO check that they exist
345        df = pd.read_csv(self.get_path("phenotypes"), header=[0, 1])
346        df.index.names = ["interval"]
347        df.index = df.index.astype(int)
348        df.columns.names = ["trait", "age_class"]
349        # TODO age_class is str
350        return df.xs(trait, axis=1)

columns .. age index .. record index value .. median phenotypic trait value

def get_survival_analysis_TE_observed_interval(self, record_index):
352    def get_survival_analysis_TE_observed_interval(self, record_index):
353        """
354        columns .. T, E
355        index .. individual
356        value .. age at event, event (1 .. died, 0 .. alive)
357        """
358        # TODO error with T and E in the record; they are being appended on top
359        assert record_index < len(self.get_path("te")), "Index out of range"
360        data = pd.read_csv(self.get_path("te")[record_index], header=0)
361        data.index.names = ["individual"]
362        return data

columns .. T, E index .. individual value .. age at event, event (1 .. died, 0 .. alive)

def get_population_size_before_reproduction(self):
364    def get_population_size_before_reproduction(self):
365        data = pd.read_csv(self.get_path("popsize_before_reproduction"), header=None)
366        data.index.names = ["steps"]
367        data.columns = ["popsize"]
368        return data
def get_population_size_after_reproduction(self):
370    def get_population_size_after_reproduction(self):
371        data = pd.read_csv(self.get_path("popsize_after_reproduction"), header=None)
372        data.index.names = ["steps"]
373        data.columns = ["popsize"]
374        return data
def get_egg_number_after_reproduction(self):
376    def get_egg_number_after_reproduction(self):
377        data = pd.read_csv(self.get_path("eggnum_after_reproduction"), header=None)
378        data.index.names = ["steps"]
379        data.columns = ["number"]
380        return data
def get_resource_amount_before_scavenging(self):
382    def get_resource_amount_before_scavenging(self):
383        data = pd.read_csv(self.get_path("resources_before_scavenging"), header=None)
384        data.index.names = ["steps"]
385        data.columns = ["resources"]
386        return data
def get_resource_amount_after_scavenging(self):
388    def get_resource_amount_after_scavenging(self):
389        data = pd.read_csv(self.get_path("resources_after_scavenging"), header=None)
390        data.index.names = ["steps"]
391        data.columns = ["resources"]
392        return data
def get_lifetime_reproduction(self):
398    def get_lifetime_reproduction(self):
399        survivorship = self.get_surv_observed_interval().cumprod(1)
400        fertility = self.get_fert_observed_interval()
401        return (survivorship * fertility).sum(axis=1)
def get_average_age_at_reproduction(self):
403    def get_average_age_at_reproduction(self):
404        bt = self.get_birth_table_observed_interval()
405        n_offspring = bt.sum(1)
406        average_age_at_reproduction = (bt * bt.columns).sum(1) / n_offspring
407        return average_age_at_reproduction
def has_ticker_stopped(self):
418    def has_ticker_stopped(self):
419        return self.get_ticker().has_stopped()