Source code for simpunch.level1

"""Generates synthetic level 1 data."""
import copy
import os
from math import floor

import astropy.units as u
import numpy as np
import reproject
import solpolpy
from astropy.io import fits
from astropy.modeling.functional_models import Gaussian2D
from astropy.table import QTable
from astropy.wcs import WCS, DistortionLookupTable
from ndcube import NDCollection, NDCube
from photutils.datasets import make_model_image, make_noise_image
from prefect import get_run_logger, task
from punchbowl.data import (NormalizedMetadata, get_base_file_name,
                            load_ndcube_from_fits, write_ndcube_to_fits)
from punchbowl.data.wcs import (calculate_celestial_wcs_from_helio,
                                calculate_pc_matrix, get_p_angle)

from simpunch.stars import (filter_for_visible_stars, find_catalog_in_image,
                            load_raw_hipparcos_catalog)
from simpunch.util import (fill_metadata_defaults, get_subdirectory,
                           update_spacecraft_location)

CURRENT_DIR = os.path.dirname(__file__)


[docs] def generate_spacecraft_wcs(spacecraft_id: str, rotation_stage: int) -> WCS: """Generate the spacecraft world coordinate system.""" angle_step = 30 if spacecraft_id in ["1", "2", "3"]: if spacecraft_id == "1": angle_wfi = (0 + angle_step * rotation_stage) % 360 * u.deg elif spacecraft_id == "2": angle_wfi = (120 + angle_step * rotation_stage) % 360 * u.deg elif spacecraft_id == "3": angle_wfi = (240 + angle_step * rotation_stage) % 360 * u.deg out_wcs_shape = [2048, 2048] out_wcs = WCS(naxis=2) out_wcs.wcs.crpix = out_wcs_shape[1] / 2 - 0.5, out_wcs_shape[0] / 2 - 0.5 out_wcs.wcs.crval = (24.75 * np.sin(angle_wfi), 24.75 * np.cos(angle_wfi)) out_wcs.wcs.cdelt = 0.02, 0.02 out_wcs.wcs.ctype = "HPLN-AZP", "HPLT-AZP" out_wcs.wcs.cunit = "deg", "deg" out_wcs.wcs.pc = calculate_pc_matrix(360 * u.deg - angle_wfi, out_wcs.wcs.cdelt) out_wcs.wcs.set_pv([(2, 1, 0.0)]) elif spacecraft_id == "4": angle_nfi = (0 + angle_step * rotation_stage) % 360 * u.deg out_wcs_shape = [2048, 2048] out_wcs = WCS(naxis=2) out_wcs.wcs.crpix = out_wcs_shape[1] / 2 + 0.5, out_wcs_shape[0] / 2 + 0.5 out_wcs.wcs.crval = 0, 0 out_wcs.wcs.cdelt = 30 / 3600, 30 / 3600 out_wcs.wcs.pc = calculate_pc_matrix(angle_nfi, out_wcs.wcs.cdelt) out_wcs.wcs.ctype = "HPLN-ARC", "HPLT-ARC" out_wcs.wcs.cunit = "deg", "deg" else: msg = "Invalid spacecraft_id." raise ValueError(msg) return out_wcs
[docs] def deproject_polar(input_data: NDCube, output_wcs: WCS, adaptive_reprojection: bool = False) -> tuple[NDCube, WCS]: """Deproject a polarized image.""" reconstructed_wcs = WCS(naxis=3) reconstructed_wcs.wcs.ctype = input_data.wcs.wcs.ctype reconstructed_wcs.wcs.cunit = input_data.wcs.wcs.cunit reconstructed_wcs.wcs.cdelt = input_data.wcs.wcs.cdelt reconstructed_wcs.wcs.crpix = input_data.wcs.wcs.crpix reconstructed_wcs.wcs.crval = input_data.wcs.wcs.crval reconstructed_wcs.wcs.pc = input_data.wcs.wcs.pc reconstructed_wcs = calculate_celestial_wcs_from_helio(reconstructed_wcs, input_data.meta.astropy_time, input_data.data.shape) reconstructed_wcs = reconstructed_wcs.dropaxis(2) output_wcs_helio = copy.deepcopy(output_wcs) output_wcs = calculate_celestial_wcs_from_helio(output_wcs, input_data.meta.astropy_time, input_data.data.shape).dropaxis(2) if adaptive_reprojection: reprojected_data = reproject.reproject_adaptive((input_data.data, reconstructed_wcs), output_wcs, (2048, 2048), roundtrip_coords=False, return_footprint=False, kernel="Gaussian", boundary_mode="ignore") else: reprojected_data = reproject.reproject_interp((input_data.data, reconstructed_wcs), output_wcs, (2048, 2048), roundtrip_coords=False, return_footprint=False) reprojected_data[np.isnan(reprojected_data)] = 0 return NDCube(data=reprojected_data, wcs=output_wcs_helio, meta=input_data.meta), output_wcs_helio
[docs] def deproject_clear(input_data: NDCube, output_wcs: WCS, adaptive_reprojection: bool = False) -> tuple[NDCube, WCS]: """Deproject a clear image.""" reconstructed_wcs = WCS(naxis=2) reconstructed_wcs.wcs.ctype = input_data.wcs.wcs.ctype reconstructed_wcs.wcs.cunit = input_data.wcs.wcs.cunit reconstructed_wcs.wcs.cdelt = input_data.wcs.wcs.cdelt reconstructed_wcs.wcs.crpix = input_data.wcs.wcs.crpix reconstructed_wcs.wcs.crval = input_data.wcs.wcs.crval reconstructed_wcs.wcs.pc = input_data.wcs.wcs.pc reconstructed_wcs = calculate_celestial_wcs_from_helio(reconstructed_wcs, input_data.meta.astropy_time, input_data.data.shape) output_wcs_helio = copy.deepcopy(output_wcs) output_wcs = calculate_celestial_wcs_from_helio(output_wcs, input_data.meta.astropy_time, input_data.data.shape) if adaptive_reprojection: reprojected_data= reproject.reproject_adaptive((input_data.data, reconstructed_wcs), output_wcs, (2048, 2048), roundtrip_coords=False, return_footprint=False, kernel="Gaussian", boundary_mode="ignore") else: reprojected_data = reproject.reproject_interp((input_data.data, reconstructed_wcs), output_wcs, (2048, 2048), roundtrip_coords=False, return_footprint=False) reprojected_data[np.isnan(reprojected_data)] = 0 return NDCube(data=reprojected_data, wcs=output_wcs_helio, meta=input_data.meta), output_wcs_helio
[docs] def mark_quality(input_data: NDCube) -> NDCube: """Mark the quality of image patches.""" return input_data
[docs] def remix_polarization(input_data: NDCube) -> NDCube: """Remix polarization from (M, Z, P) to (P1, P2, P3) using solpolpy.""" # Unpack data into a NDCollection object w = WCS(naxis=2) data_collection = NDCollection([("M", NDCube(data=input_data.data[0], wcs=w, meta={})), ("Z", NDCube(data=input_data.data[1], wcs=w, meta={})), ("P", NDCube(data=input_data.data[2], wcs=w, meta={}))]) data_collection["M"].meta["POLAR"] = -60. * u.degree data_collection["Z"].meta["POLAR"] = 0. * u.degree data_collection["P"].meta["POLAR"] = 60. * u.degree # TODO - Remember that this needs to be the instrument frame MZP, not the mosaic frame resolved_data_collection = solpolpy.resolve(data_collection, "npol", out_angles=[-60, 0, 60] * u.deg, imax_effect=False) # Repack data data_list = [] wcs_list = [] uncertainty_list = [] for key in resolved_data_collection: data_list.append(resolved_data_collection[key].data) wcs_list.append(resolved_data_collection[key].wcs) uncertainty_list.append(resolved_data_collection[key].uncertainty) # Remove alpha channel if present if "alpha" in resolved_data_collection: data_list.pop() wcs_list.pop() uncertainty_list.pop() # Repack into an NDCube object new_data = np.stack(data_list, axis=0) if uncertainty_list[0] is not None: # noqa: SIM108 new_uncertainty = np.stack(uncertainty_list, axis=0) else: new_uncertainty = None new_wcs = input_data.wcs.copy() return NDCube(data=new_data, wcs=new_wcs, uncertainty=new_uncertainty, meta=input_data.meta)
# TODO - add scaling factor
[docs] def add_distortion(input_data: NDCube) -> NDCube: """Add a distortion model to the WCS.""" filename_distortion = ( os.path.join(CURRENT_DIR, "data/distortion_NFI.fits") if input_data.meta["OBSCODE"].value == "4" else os.path.join(CURRENT_DIR, "data/distortion_WFI.fits") ) with fits.open(filename_distortion) as hdul: err_x = hdul[1].data err_y = hdul[2].data crpix = err_x.shape[1] / 2 + 0.5, err_x.shape[0] / 2 + 0.5 coord = input_data.wcs.pixel_to_world(crpix[0], crpix[1]) crval = (coord.Tx.to(u.deg).value, coord.Ty.to(u.deg).value) cdelt = (input_data.wcs.wcs.cdelt[0] * input_data.wcs.wcs.cdelt[0] / err_x.shape[1], input_data.wcs.wcs.cdelt[0] * input_data.wcs.wcs.cdelt[0] / err_x.shape[1]) cpdis1 = DistortionLookupTable( -err_x.astype(np.float32), crpix, crval, cdelt, ) cpdis2 = DistortionLookupTable( -err_y.astype(np.float32), crpix, crval, cdelt, ) input_data.wcs.cpdis1 = cpdis1 input_data.wcs.cpdis2 = cpdis2 return input_data
[docs] def generate_starfield(wcs: WCS, img_shape: (int, int), fwhm: float, wcs_mode: str = "all", mag_set: float = 0, flux_set: float = 500_000, noise_mean: float | None = 25.0, noise_std: float | None = 5.0, dimmest_magnitude: float = 8) -> (np.ndarray, QTable): """Generate a realistic starfield.""" sigma = fwhm / 2.355 catalog = load_raw_hipparcos_catalog() filtered_catalog = filter_for_visible_stars(catalog, dimmest_magnitude=dimmest_magnitude) stars = find_catalog_in_image(filtered_catalog, wcs, img_shape, mode=wcs_mode) star_mags = stars["Vmag"] sources = QTable() sources["x_mean"] = stars["x_pix"] sources["y_mean"] = stars["y_pix"] sources["x_stddev"] = np.ones(len(stars)) * sigma sources["y_stddev"] = np.ones(len(stars)) * sigma sources["amplitude"] = flux_set * np.power(10, -0.4 * (star_mags - mag_set)) sources["theta"] = np.zeros(len(stars)) model = Gaussian2D() model_shape = (25, 25) fake_image = make_model_image(img_shape, model, sources, model_shape=model_shape, x_name="x_mean", y_name="y_mean") if noise_mean is not None and noise_std is not None: # we only add noise if it's specified fake_image += make_noise_image(img_shape, "gaussian", mean=noise_mean, stddev=noise_std) return fake_image, sources
[docs] def add_starfield_polarized(input_collection: NDCollection, polfactor: tuple = (0.2, 0.3, 0.5)) -> NDCollection: """Add synthetic polarized starfield.""" input_data = input_collection["Z"] wcs_stellar_input = calculate_celestial_wcs_from_helio(input_data.wcs, input_data.meta.astropy_time, input_data.data.shape) starfield, stars = generate_starfield(wcs_stellar_input, input_data.data.shape, flux_set=100*2.0384547E-9, fwhm=3, dimmest_magnitude=12, noise_mean=None, noise_std=None) starfield_data = np.zeros(input_data.data.shape) starfield_data[:, :] = starfield * (np.logical_not(np.isclose(input_data.data, 0, atol=1E-18))) # Converting the input data polarization to celestial basis mzp_angles = ([input_cube.meta["POLAR"].value for label, input_cube in input_collection.items() if label != "alpha"]) * u.degree cel_north_off = get_p_angle(time=input_collection["Z"].meta["DATE-OBS"].value) new_angles = (mzp_angles + cel_north_off).value * u.degree valid_keys = [key for key in input_collection if key != "alpha"] meta_a = dict(NormalizedMetadata.to_fits_header(input_collection[valid_keys[0]].meta, wcs=input_collection[valid_keys[0]].wcs)) meta_b = dict(NormalizedMetadata.to_fits_header(input_collection[valid_keys[1]].meta, wcs=input_collection[valid_keys[1]].wcs)) meta_c = dict(NormalizedMetadata.to_fits_header(input_collection[valid_keys[2]].meta, wcs=input_collection[valid_keys[2]].wcs)) meta_a["POLAR"] = meta_a["POLAR"] * u.degree meta_b["POLAR"] = meta_b["POLAR"] * u.degree meta_c["POLAR"] = meta_c["POLAR"] * u.degree data_collection = NDCollection( [(str(valid_keys[0]), NDCube(data=input_collection[valid_keys[0]].data, meta=meta_a, wcs=input_collection[valid_keys[0]].wcs)), (str(valid_keys[1]), NDCube(data=input_collection[valid_keys[1]].data, meta=meta_b, wcs=input_collection[valid_keys[1]].wcs)), (str(valid_keys[2]), NDCube(data=input_collection[valid_keys[2]].data, meta=meta_c, wcs=input_collection[valid_keys[2]].wcs))], aligned_axes="all") input_data_cel = solpolpy.resolve(data_collection, "npol", reference_angle=0 * u.degree, out_angles=new_angles) valid_keys = [key for key in input_data_cel if key != "alpha"] dummy_polarmaps = [] for k, _ in enumerate(valid_keys): # Generate an all-sky polarization map for each of the three polarization states dummy_polarmaps.append(generate_dummy_polarization(pol_factor=polfactor[k])) polarmap_wcs = dummy_polarmaps[0].wcs dummy_polarmaps = [d.data for d in dummy_polarmaps] # Reproject the polarization maps in one go into the frame of the input image polar_rois = reproject.reproject_adaptive( (np.array(dummy_polarmaps), polarmap_wcs), wcs_stellar_input, input_data.data.shape, roundtrip_coords=False, return_footprint=False, x_cyclic=True, conserve_flux=True, center_jacobian=True, despike_jacobian=True) # Apply the polarization maps to the starfield and add them to the data for k, key in enumerate(valid_keys): input_data_cel[key].data[...] = input_data_cel[key].data + polar_rois[k] * starfield_data mzp_data_instru = solpolpy.resolve(input_data_cel, "mzpinstru", reference_angle=0 * u.degree) # Instrument MZP valid_keys = [key for key in mzp_data_instru if key != "alpha"] out_meta = {"M": copy.deepcopy(input_collection["M"].meta), "Z": copy.deepcopy(input_collection["Z"].meta), "P": copy.deepcopy(input_collection["P"].meta)} for out_pol, meta_item in out_meta.items(): for key, kind in zip(["POLAR", "POLARREF", "POLAROFF"], [int, str, float], strict=False): if isinstance(mzp_data_instru[out_pol].meta[key], u.Quantity): meta_item[key] = kind(mzp_data_instru[out_pol].meta[key].value) else: meta_item[key] = kind(mzp_data_instru[out_pol].meta[key]) return NDCollection( [(str(key), NDCube(data=mzp_data_instru[key].data, meta=out_meta[key], wcs=mzp_data_instru[key].wcs)) for key in valid_keys], aligned_axes="all")
[docs] def add_starfield_clear(input_data: NDCube) -> NDCube: """Add synthetic starfield.""" wcs_stellar_input = calculate_celestial_wcs_from_helio(input_data.wcs, input_data.meta.astropy_time, input_data.data.shape) starfield, stars = generate_starfield(wcs_stellar_input, input_data.data[:, :].shape, flux_set=0.1*2.0384547E-9, fwhm=3, dimmest_magnitude=12, noise_mean=None, noise_std=None) starfield_data = np.zeros(input_data.data.shape) starfield_data[:, :] = starfield * (np.logical_not(np.isclose(input_data.data[:, :], 0, atol=1E-18))) input_data.data[...] = input_data.data[...] + starfield_data return input_data
[docs] def generate_dummy_polarization(map_scale: float = 0.225, pol_factor: float = 0.5) -> NDCube: """Create a synthetic polarization map.""" shape = [floor(180 / map_scale), floor(360 / map_scale)] xcoord = np.linspace(-pol_factor, pol_factor, shape[1]) ycoord = np.linspace(-pol_factor, pol_factor, shape[0]) xin, yin = np.meshgrid(xcoord, ycoord) zin = pol_factor - (xin ** 2 + yin ** 2) wcs_sky = WCS(naxis=2) wcs_sky.wcs.crpix = [shape[1] / 2 + .5, shape[0] / 2 + .5] wcs_sky.wcs.cdelt = np.array([map_scale, map_scale]) wcs_sky.wcs.crval = [180.0, 0.0] wcs_sky.wcs.ctype = ["RA---CAR", "DEC--CAR"] wcs_sky.wcs.cunit = "deg", "deg" return NDCube(data=zin, wcs=wcs_sky)
[docs] @task def generate_l1_pmzp(input_file: str, path_output: str, rotation_stage: int, spacecraft_id: str) -> list[str]: """Generate level 1 polarized synthetic data.""" logger = get_run_logger() input_pdata = load_ndcube_from_fits(input_file) logger.info(f"Read input file {input_file}") # Define the output data product product_code = "PM" + spacecraft_id product_level = "1" output_meta = NormalizedMetadata.load_template(product_code, product_level) fill_metadata_defaults(output_meta) output_meta["DATE-OBS"] = input_pdata.meta["DATE-OBS"].value output_meta["DESCRPTN"] = "Simulated " + output_meta["DESCRPTN"].value output_meta["TITLE"] = "Simulated " + output_meta["TITLE"].value output_wcs = generate_spacecraft_wcs(spacecraft_id, rotation_stage) # Synchronize overlapping metadata keys output_header = output_meta.to_fits_header(output_wcs) for key in output_header: if (key in input_pdata.meta) and output_header[key] == "" and key not in ("COMMENT", "HISTORY"): output_meta[key].value = input_pdata.meta[key].value output_data, output_wcs = deproject_polar(input_pdata, output_wcs) logger.info("Deprojected") output_data = mark_quality(output_data) logger.info("Quality marked") output_data = remix_polarization(output_data) logger.info("Polarization mixed") output_mmeta = copy.deepcopy(output_meta) output_zmeta = copy.deepcopy(output_meta) output_pmeta = copy.deepcopy(output_meta) output_mwcs = copy.deepcopy(output_wcs) output_zwcs = copy.deepcopy(output_wcs) output_pwcs = copy.deepcopy(output_wcs) output_mdata = NDCube(data=output_data.data[0, :, :].astype(np.float32), wcs=output_mwcs, meta=output_mmeta) output_zdata = NDCube(data=output_data.data[1, :, :].astype(np.float32), wcs=output_zwcs, meta=output_zmeta) output_pdata = NDCube(data=output_data.data[2, :, :].astype(np.float32), wcs=output_pwcs, meta=output_pmeta) output_mdata.meta["TYPECODE"] = "PM" output_zdata.meta["TYPECODE"] = "PZ" output_pdata.meta["TYPECODE"] = "PP" output_mdata.meta["POLAR"] = -60 output_zdata.meta["POLAR"] = 0 output_pdata.meta["POLAR"] = 60 # Add distortion # output_mdata = add_distortion(output_mdata) # noqa: ERA001 # output_zdata = add_distortion(output_zdata) # noqa: ERA001 # output_pdata = add_distortion(output_pdata) # noqa: ERA001 # logger.info("Distortion added") # noqa: ERA001 output_collection = NDCollection( [("M", output_mdata), ("Z", output_zdata), ("P", output_pdata)], aligned_axes="all") output_mzp = add_starfield_polarized(output_collection) logger.info("Starfield added") output_mdata = output_mzp["M"] output_zdata = output_mzp["Z"] output_pdata = output_mzp["P"] output_pdata = update_spacecraft_location(output_pdata, output_pdata.meta.astropy_time) output_mdata = update_spacecraft_location(output_mdata, output_mdata.meta.astropy_time) output_zdata = update_spacecraft_location(output_zdata, output_zdata.meta.astropy_time) # Write out paths = [] path = os.path.join(path_output, get_subdirectory(output_mdata), get_base_file_name(output_mdata) + ".fits") os.makedirs(os.path.dirname(path), exist_ok=True) paths.append(path) logger.info(f"Writing data to {path}") write_ndcube_to_fits(output_mdata, path) path = os.path.join(path_output, get_subdirectory(output_zdata), get_base_file_name(output_zdata) + ".fits") os.makedirs(os.path.dirname(path), exist_ok=True) paths.append(path) logger.info(f"Writing data to {path}") write_ndcube_to_fits(output_zdata, path) path = os.path.join(path_output, get_subdirectory(output_pdata), get_base_file_name(output_pdata) + ".fits") os.makedirs(os.path.dirname(path), exist_ok=True) paths.append(path) logger.info(f"Writing data to {path}") write_ndcube_to_fits(output_pdata, path) logger.info("All data written") return paths
[docs] @task def generate_l1_cr(input_file: str, path_output: str, rotation_stage: int, spacecraft_id: str) -> str: """Generate level 1 clear synthetic data.""" logger = get_run_logger() input_pdata = load_ndcube_from_fits(input_file) logger.info(f"Read input file {input_file}") # Define the output data product product_code = "CR" + spacecraft_id product_level = "1" output_meta = NormalizedMetadata.load_template(product_code, product_level) fill_metadata_defaults(output_meta) output_meta["DATE-OBS"] = input_pdata.meta["DATE-OBS"].value output_wcs = generate_spacecraft_wcs(spacecraft_id, rotation_stage) # Synchronize overlapping metadata keys output_header = output_meta.to_fits_header(output_wcs) for key in output_header: if (key in input_pdata.meta) and output_header[key] == "" and key not in ("COMMENT", "HISTORY"): output_meta[key].value = input_pdata.meta[key].value # Deproject to spacecraft frame output_data, output_wcs = deproject_clear(input_pdata, output_wcs) logger.info("Deprojected") # Quality marking output_data = mark_quality(output_data) logger.info("Quality marked") # output_data = add_distortion(output_data) # noqa: ERA001 # logger.info("Distortion added") # noqa: ERA001 output_data = add_starfield_clear(output_data) logger.info("Starfield added") output_cmeta = copy.deepcopy(output_meta) output_cwcs = copy.deepcopy(output_wcs) # Package into NDCube objects output_cdata = NDCube(data=output_data.data[:, :].astype(np.float32), wcs=output_cwcs, meta=output_cmeta) output_cdata.meta["TYPECODE"] = "CR" output_cdata.meta["DESCRPTN"] = "Simulated" + output_cdata.meta["DESCRPTN"].value output_cdata.meta["TITLE"] = "Simulated " + output_cdata.meta["TITLE"].value output_cdata.meta["POLAR"] = 9999 output_cdata = update_spacecraft_location(output_cdata, output_cdata.meta.astropy_time) # Write out out_path = os.path.join(path_output, get_subdirectory(output_cdata), get_base_file_name(output_cdata) + ".fits") os.makedirs(os.path.dirname(out_path), exist_ok=True) logger.info(f"Writing data to {out_path}") write_ndcube_to_fits(output_cdata, out_path) logger.info("Data written") return out_path