aegis_sim.submodels.reproduction.recombination

  1import numpy as np
  2from numba import njit
  3from aegis_sim import variables
  4from aegis_sim.utilities.funcs import profile_time
  5
  6
  7def recombination(genomes, RECOMBINATION_RATE):
  8    """Return recombined chromatids."""
  9
 10    if RECOMBINATION_RATE == 0:
 11        return genomes
 12
 13    # Recombine two chromatids but pass only one;
 14    #   thus double the number of chromatids, recobine,
 15    #   then return only one chromatid from each chromatid pair
 16    genomes = genomes[np.repeat(np.arange(len(genomes)), 2)]
 17
 18    # Flatten loci and bits
 19    flat_genomes = genomes.reshape(len(genomes), 2, -1)
 20
 21    # Get chromatids
 22    chromatid1 = flat_genomes[:, 0]
 23    chromatid2 = flat_genomes[:, 1]
 24
 25    # Make choice array: when to take recombined and when to take original loci
 26    # -1 means synapse; +1 means clear
 27    rr = RECOMBINATION_RATE / 2  # / 2 because you are generating two random vectors (fwd and bkd)
 28    reco_fwd = (variables.rng.random(chromatid1.shape) < rr) * -2 + 1
 29    reco_bkd = (variables.rng.random(chromatid2.shape) < rr) * -2 + 1
 30
 31    # Propagate synapse
 32    reco_fwd_cum = np.cumprod(reco_fwd, axis=1)
 33    reco_bkd_cum = np.cumprod(reco_bkd[:, ::-1], axis=1)[:, ::-1]
 34
 35    # Recombine if both sites recombining
 36    reco_final = (reco_fwd_cum + reco_bkd_cum) == -2
 37
 38    # Choose bits from first or second chromatid
 39    # recombined = np.empty(flat_genomes.shape, bool)
 40    recombined = np.empty(flat_genomes.shape, dtype=np.bool_)
 41    recombined[:, 0] = np.where(reco_final, chromatid2, chromatid1)
 42    recombined[:, 1] = np.where(reco_final, chromatid1, chromatid2)
 43
 44    recombined = recombined.reshape(genomes.shape)
 45    recombined = recombined[::2]  # Look at first comment in the function
 46
 47    return recombined
 48
 49
 50# # Pure-Python loop predecessor of `recombination_via_pairs` (superseded by the numba-accelerated version further down)
 51# def recombination_via_pairs(genomes, RECOMBINATION_RATE):
 52
 53#     if RECOMBINATION_RATE == 0:
 54#         return genomes
 55
 56#     flat_genomes = genomes.reshape(len(genomes), 2, -1)
 57
 58#     n_sites = flat_genomes.shape[-1]
 59
 60#     n_recombination_sites = np.random.binomial(
 61#         n=n_sites,
 62#         p=RECOMBINATION_RATE,
 63#         size=len(flat_genomes),
 64#     )
 65
 66#     # Produce all random numbers immediately
 67#     chiasmata_list = variables.rng.integers(
 68#         low=1,
 69#         high=n_sites,
 70#         size=(len(n_recombination_sites), max(n_recombination_sites)),
 71#         dtype=np.int32,
 72#     )  # [low, high)
 73
 74#     for i, (chiasmata, n) in enumerate(zip(chiasmata_list, n_recombination_sites)):
 75#         for chiasma in chiasmata[:n]:
 76#             flat_genomes[i, 0, :chiasma], flat_genomes[i, 1, :chiasma] = (
 77#                 flat_genomes[i, 1, :chiasma],
 78#                 flat_genomes[i, 0, :chiasma],
 79#             )
 80
 81#     unflattened_genomes = flat_genomes.reshape(genomes.shape)
 82
 83#     return unflattened_genomes
 84
 85
 86@njit
 87def recombination_via_pairs_numba(flat_genomes, n_recombination_sites, chiasmata_list):
 88    """Optimized recombination using difference-array approach.
 89
 90    Instead of copying genome slices for each chiasma (O(n_sites) per chiasma),
 91    we compute the net swap state per site using a difference array and prefix sum.
 92    Each chiasma at position c toggles sites [0, c), so we track toggle counts
 93    and only swap sites toggled an odd number of times.
 94
 95    This avoids the expensive slice-copy pattern that scaled poorly with genome size.
 96    """
 97    n_individuals = len(flat_genomes)
 98    n_sites = flat_genomes.shape[2]
 99
100    for i in range(n_individuals):
101        n_reco = n_recombination_sites[i]
102        if n_reco == 0:
103            continue
104
105        # Build difference array for toggle counts
106        counts = np.zeros(n_sites + 1, dtype=np.int32)
107        for j in range(n_reco):
108            c = chiasmata_list[i, j]
109            counts[0] += 1
110            if c < n_sites + 1:
111                counts[c] -= 1
112
113        # Prefix sum and swap where toggled odd number of times
114        running = np.int32(0)
115        for k in range(n_sites):
116            running += counts[k]
117            if running % 2 == 1:
118                tmp = flat_genomes[i, 0, k]
119                flat_genomes[i, 0, k] = flat_genomes[i, 1, k]
120                flat_genomes[i, 1, k] = tmp
121
122    return flat_genomes
123
124
def recombination_via_pairs(genomes, RECOMBINATION_RATE):
    """Return genomes after chiasma-based exchange between the chromatid pair.

    The genome of each individual is flattened to two chromatids of ``n_sites``
    sites. The number of chiasmata per individual is drawn as
    ``Binomial(n_sites, RECOMBINATION_RATE)``, and their positions are drawn
    uniformly from ``[1, n_sites)``. Each chiasma at ``c`` swaps sites
    ``[0, c)`` between the chromatids (see ``recombination_via_pairs_numba``).

    All randomness comes from ``variables.rng``, the same generator used by
    ``recombination``, so runs are reproducible from its seed.

    Parameters
    ----------
    genomes : np.ndarray
        Array whose first axis indexes individuals; each individual's data is
        reshaped to ``(2, -1)`` (presumably two chromatids; confirm with
        callers).
    RECOMBINATION_RATE : float
        Per-site chiasma probability; ``0`` disables recombination.

    Returns
    -------
    np.ndarray
        The input object itself when the rate is 0. Otherwise a new array with
        the same shape as ``genomes``; the input is never modified.
    """
    if RECOMBINATION_RATE == 0:
        return genomes

    if len(genomes) == 0:
        # reshape(0, 2, -1) is rejected by NumPy and max() of an empty array
        # raises, so short-circuit with an (empty) copy.
        return genomes.copy()

    flat_genomes = genomes.reshape(len(genomes), 2, -1).copy()
    n_sites = flat_genomes.shape[-1]

    if n_sites < 2:
        # Chiasma positions are drawn from [1, n_sites); with fewer than two
        # sites there is no valid position (and rng.integers would raise).
        return flat_genomes.reshape(genomes.shape)

    n_recombination_sites = variables.rng.binomial(n=n_sites, p=RECOMBINATION_RATE, size=len(flat_genomes))

    # Draw all chiasma positions at once; rows are padded to the largest count
    # and the numba kernel reads only the first n_recombination_sites[i] entries.
    max_n = int(n_recombination_sites.max())
    chiasmata_list = variables.rng.integers(low=1, high=n_sites, size=(len(flat_genomes), max_n), dtype=np.int32)

    flat_genomes = recombination_via_pairs_numba(flat_genomes, n_recombination_sites, chiasmata_list)

    return flat_genomes.reshape(genomes.shape)
def recombination(genomes, RECOMBINATION_RATE):
 8def recombination(genomes, RECOMBINATION_RATE):
 9    """Return recombined chromatids."""
10
11    if RECOMBINATION_RATE == 0:
12        return genomes
13
14    # Recombine two chromatids but pass only one;
15    #   thus double the number of chromatids, recobine,
16    #   then return only one chromatid from each chromatid pair
17    genomes = genomes[np.repeat(np.arange(len(genomes)), 2)]
18
19    # Flatten loci and bits
20    flat_genomes = genomes.reshape(len(genomes), 2, -1)
21
22    # Get chromatids
23    chromatid1 = flat_genomes[:, 0]
24    chromatid2 = flat_genomes[:, 1]
25
26    # Make choice array: when to take recombined and when to take original loci
27    # -1 means synapse; +1 means clear
28    rr = RECOMBINATION_RATE / 2  # / 2 because you are generating two random vectors (fwd and bkd)
29    reco_fwd = (variables.rng.random(chromatid1.shape) < rr) * -2 + 1
30    reco_bkd = (variables.rng.random(chromatid2.shape) < rr) * -2 + 1
31
32    # Propagate synapse
33    reco_fwd_cum = np.cumprod(reco_fwd, axis=1)
34    reco_bkd_cum = np.cumprod(reco_bkd[:, ::-1], axis=1)[:, ::-1]
35
36    # Recombine if both sites recombining
37    reco_final = (reco_fwd_cum + reco_bkd_cum) == -2
38
39    # Choose bits from first or second chromatid
40    # recombined = np.empty(flat_genomes.shape, bool)
41    recombined = np.empty(flat_genomes.shape, dtype=np.bool_)
42    recombined[:, 0] = np.where(reco_final, chromatid2, chromatid1)
43    recombined[:, 1] = np.where(reco_final, chromatid1, chromatid2)
44
45    recombined = recombined.reshape(genomes.shape)
46    recombined = recombined[::2]  # Look at first comment in the function
47
48    return recombined

Return recombined chromatids.

@njit
def recombination_via_pairs_numba(flat_genomes, n_recombination_sites, chiasmata_list):
 87@njit
 88def recombination_via_pairs_numba(flat_genomes, n_recombination_sites, chiasmata_list):
 89    """Optimized recombination using difference-array approach.
 90
 91    Instead of copying genome slices for each chiasma (O(n_sites) per chiasma),
 92    we compute the net swap state per site using a difference array and prefix sum.
 93    Each chiasma at position c toggles sites [0, c), so we track toggle counts
 94    and only swap sites toggled an odd number of times.
 95
 96    This avoids the expensive slice-copy pattern that scaled poorly with genome size.
 97    """
 98    n_individuals = len(flat_genomes)
 99    n_sites = flat_genomes.shape[2]
100
101    for i in range(n_individuals):
102        n_reco = n_recombination_sites[i]
103        if n_reco == 0:
104            continue
105
106        # Build difference array for toggle counts
107        counts = np.zeros(n_sites + 1, dtype=np.int32)
108        for j in range(n_reco):
109            c = chiasmata_list[i, j]
110            counts[0] += 1
111            if c < n_sites + 1:
112                counts[c] -= 1
113
114        # Prefix sum and swap where toggled odd number of times
115        running = np.int32(0)
116        for k in range(n_sites):
117            running += counts[k]
118            if running % 2 == 1:
119                tmp = flat_genomes[i, 0, k]
120                flat_genomes[i, 0, k] = flat_genomes[i, 1, k]
121                flat_genomes[i, 1, k] = tmp
122
123    return flat_genomes

Optimized recombination using difference-array approach.

Instead of copying genome slices for each chiasma (O(n_sites) per chiasma), we compute the net swap state per site using a difference array and prefix sum. Each chiasma at position c toggles sites [0, c), so we track toggle counts and only swap sites toggled an odd number of times.

This avoids the expensive slice-copy pattern that scaled poorly with genome size.

def recombination_via_pairs(genomes, RECOMBINATION_RATE):
def recombination_via_pairs(genomes, RECOMBINATION_RATE):
    """Return genomes after chiasma-based exchange between the chromatid pair.

    The genome of each individual is flattened to two chromatids of ``n_sites``
    sites. The number of chiasmata per individual is drawn as
    ``Binomial(n_sites, RECOMBINATION_RATE)``, and their positions are drawn
    uniformly from ``[1, n_sites)``. Each chiasma at ``c`` swaps sites
    ``[0, c)`` between the chromatids (see ``recombination_via_pairs_numba``).

    All randomness comes from ``variables.rng``, the same generator used by
    ``recombination``, so runs are reproducible from its seed.

    Parameters
    ----------
    genomes : np.ndarray
        Array whose first axis indexes individuals; each individual's data is
        reshaped to ``(2, -1)`` (presumably two chromatids; confirm with
        callers).
    RECOMBINATION_RATE : float
        Per-site chiasma probability; ``0`` disables recombination.

    Returns
    -------
    np.ndarray
        The input object itself when the rate is 0. Otherwise a new array with
        the same shape as ``genomes``; the input is never modified.
    """
    if RECOMBINATION_RATE == 0:
        return genomes

    if len(genomes) == 0:
        # reshape(0, 2, -1) is rejected by NumPy and max() of an empty array
        # raises, so short-circuit with an (empty) copy.
        return genomes.copy()

    flat_genomes = genomes.reshape(len(genomes), 2, -1).copy()
    n_sites = flat_genomes.shape[-1]

    if n_sites < 2:
        # Chiasma positions are drawn from [1, n_sites); with fewer than two
        # sites there is no valid position (and rng.integers would raise).
        return flat_genomes.reshape(genomes.shape)

    n_recombination_sites = variables.rng.binomial(n=n_sites, p=RECOMBINATION_RATE, size=len(flat_genomes))

    # Draw all chiasma positions at once; rows are padded to the largest count
    # and the numba kernel reads only the first n_recombination_sites[i] entries.
    max_n = int(n_recombination_sites.max())
    chiasmata_list = variables.rng.integers(low=1, high=n_sites, size=(len(flat_genomes), max_n), dtype=np.int32)

    flat_genomes = recombination_via_pairs_numba(flat_genomes, n_recombination_sites, chiasmata_list)

    return flat_genomes.reshape(genomes.shape)
139    return flat_genomes.reshape(genomes.shape)