Source code for simpunch.spike

"""Tools for generating realistic spikes."""
import os
import random
from glob import glob

import numpy as np

TESTDATA_DIR = os.path.dirname(__file__)
HIT_MAP_DIR = os.path.join(TESTDATA_DIR, "data/hit_maps/")
SPIKE_SCALING_MEAN = 2**16-5_000 # brightness of the average spike
SPIKE_SCALING_STD = SPIKE_SCALING_MEAN * 0.01 # width of the normal distribution for generating spike brightness
SPIKE_FREQUENCY = 40 * 49  # number of spikes in an image

[docs] def read_hit_map(path: str) -> np.ndarray: """Read in a hit map file and make an image of the spikes.""" with open(path) as file: line = file.readline() num_spikes, width, height = map(int, line.split()) table = np.genfromtxt(path, skip_header=1) image = np.zeros((width, height), dtype=float) image[table[:, 0].astype(int), table[:, 1].astype(int)] = table[:, 2] / np.percentile(table[:, 2], 50) return image
[docs] def load_spike_library() -> np.ndarray: """Load all the spike images as a 3D numpy array.""" paths = glob(HIT_MAP_DIR + "*.hmp") return np.array([read_hit_map(path) for path in paths])
[docs] def generate_spike_image( image_shape: (int, int), spike_frequency: int = SPIKE_FREQUENCY, spike_scaling_mean: float = SPIKE_SCALING_MEAN, spike_scaling_std: float = SPIKE_SCALING_STD, max_spike: int = 2**16, patch_size: int = 50, rotate: bool = True, transpose: bool = True, library: np.ndarray | None = None) -> np.ndarray: """Generate a realistic spike image.""" if library is None: library = load_spike_library() spike_target = np.random.choice(image_shape[0] * image_shape[1], spike_frequency) spike_target_x, spike_target_y = np.unravel_index(spike_target, image_shape) spike_source = np.random.choice(np.prod(library.shape), spike_frequency) spike_source_i, spike_source_j, spike_source_k = np.unravel_index(spike_source, library.shape) spike_values = np.random.normal(spike_scaling_mean, spike_scaling_std, spike_frequency) spike_image = np.zeros(image_shape) for spike_value, i, j, k, x, y in zip(spike_values, spike_source_i, spike_source_j, spike_source_k, spike_target_x, spike_target_y, strict=False): patch = library[i, j:j+patch_size, k:k+patch_size] * spike_value pad_amount = tuple([(0, patch_size-s) for s in patch.shape]) patch = np.pad(patch, pad_amount, mode="constant", constant_values=0) if transpose and random.random() < 0.5: # noqa: PLR2004 patch = patch.T if rotate: n = int(random.random() // 0.25) patch = np.rot90(patch, k=n) target_shape = spike_image[x:x+patch_size, y:y+patch_size].shape spike_image[x:x+patch_size, y:y+patch_size] = patch[:target_shape[0], :target_shape[1]] return np.clip(spike_image, 0, max_spike)