Source code for sndata.utils

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

"""The ``utils`` module provides an assorted collection of general utilities
used when building data access classes for a given survey / data release.
"""

import functools
import os
import sys
import tarfile
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import IO, Union

import numpy as np
import requests
import sncosmo
from astropy.coordinates import Angle
from pytz import utc
from tqdm import tqdm

from .exceptions import NoDownloadedData


def hourangle_to_degrees(
        rah: float, ram: float, ras: float, dec_sign: str,
        decd: float, decm: float, decs: float) -> (float, float):
    """Convert from hour angle to degrees

    Args:
        rah: RA hours
        ram: RA minutes
        ras: RA seconds
        dec_sign: Sign of the declination ('+' or '-')
        decd: Dec degrees
        decm: Dec arcmin
        decs: Dec arcsec

    Returns:
        The RA and Dec in decimal degrees
    """

    # Convert Right Ascension
    ra = Angle((rah, ram, ras), unit='hourangle').to('deg').value

    # Convert Declination, applying the sign to the full angle so the
    # arcminute and arcsecond terms are subtracted for negative declinations
    sign = -1 if dec_sign == '-' else 1
    dec = sign * (
            decd +          # Already in degrees
            decm / 60 +     # arcmin to degrees
            decs / 60 / 60  # arcsec to degrees
    )

    return ra, dec
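
# Illustrative usage sketch (not part of the original module): converting
# RA = 10h 41m 4.3s, Dec = -59d 41' 4" to decimal degrees. The coordinate
# values are arbitrary.
#
#   ra, dec = hourangle_to_degrees(10, 41, 4.3, '-', 59, 41, 4.0)
#   # ra is roughly 160.27, dec is roughly -59.68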

def find_data_dir(survey_abbrev: str, release: str) -> Path:
    """Determine the directory where data files are stored for a data release

    Args:
        survey_abbrev: Abbreviation of the survey to load data for (e.g., CSP)
        release: Name of the data release from the survey (e.g., DR1)

    Returns:
        The path of the directory where the data is stored
    """

    # Enforce the use of lowercase file names
    safe_survey = survey_abbrev.lower().replace(' ', '_')
    safe_release = release.lower().replace(' ', '_')

    # Default to using data directory specified in the environment
    if 'SNDATA_DIR' in os.environ:
        base_dir = Path(os.environ['SNDATA_DIR']).resolve()

    else:
        base_dir = Path(__file__).resolve().parent / 'data'

    data_dir = base_dir / safe_survey / safe_release
    return data_dir
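
# Illustrative usage sketch (not part of the original module). The returned
# path depends on whether the SNDATA_DIR environment variable is set:
#
#   data_dir = find_data_dir('CSP', 'DR1')
#   # <SNDATA_DIR>/csp/dr1 if SNDATA_DIR is set,
#   # otherwise <package directory>/data/csp/dr1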

def lru_copy_cache(maxsize: int = 128, typed: bool = False, copy: bool = True):
    """Decorator to cache the return of a function

    Similar to ``functools.lru_cache``, but allows a copy of the cached
    value to be returned, thus preventing mutation of the cache.

    Args:
        maxsize: Maximum size of the cache
        typed: Cache objects of different types separately (Default: False)
        copy: Return a copy of the cached item (Default: True)

    Returns:
        A decorator
    """

    if not copy:
        # Return the normal function cache
        return functools.lru_cache(maxsize, typed)

    # noinspection PyMissingOrEmptyDocstring
    def decorator(f):
        cached_func = functools.lru_cache(maxsize, typed)(f)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return deepcopy(cached_func(*args, **kwargs))

        return wrapper

    return decorator
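
# Illustrative usage sketch (not part of the original module);
# ``load_table`` is a hypothetical function standing in for an expensive
# read. Because the decorator returns a deep copy, mutating the returned
# value does not corrupt the cache:
#
#   @lru_copy_cache(maxsize=32)
#   def load_table(name):
#       return [1, 2, 3]
#
#   first = load_table('demo')
#   first.append(4)      # mutates only the returned copy
#   load_table('demo')   # still returns [1, 2, 3]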

def build_pbar(data: iter, verbose: Union[bool, dict]):
    """Cast an iterable into a progress bar

    If verbose is False, return ``data`` unchanged.

    Args:
        data: An iterable object
        verbose: Arguments for tqdm.tqdm
    """

    if isinstance(verbose, dict):
        iter_data = tqdm(data, **verbose)

    elif verbose:
        iter_data = tqdm(data)

    else:
        iter_data = data

    return iter_data

@np.vectorize
def convert_to_jd(date: float, format: str) -> float:
    """Convert dates into JD

    Can convert the Snoopy, MJD, or UT time standards.

    Args:
        date: Time stamp value
        format: Either ``snpy``, ``mjd``, or ``ut``

    Returns:
        The time value in JD format
    """

    snoopy_offset = 53000  # Conversion from Snoopy to MJD
    mjd_offset = 2400000.5  # Conversion from MJD to JD

    if format.lower() == 'snpy':
        return date + snoopy_offset + mjd_offset

    elif format.lower() == 'mjd':
        return date + mjd_offset

    elif format.lower() == 'ut':
        # Break date down into year, month, and days
        str_date = str(date)
        year = int(str_date[:4])
        month = int(str_date[4:6])
        day = int(str_date[6:8])
        fractional_days = float(str_date[8:])

        # Convert fractional days into hours, minutes, and seconds
        hours_in_day = 24
        min_in_hour = 60
        sec_in_min = 60
        microsec_in_sec = 1e+6

        hours = fractional_days * hours_in_day
        minutes = (hours * min_in_hour) - (int(hours) * min_in_hour)
        seconds = (minutes * sec_in_min) - (int(minutes) * sec_in_min)
        microsec = (seconds * microsec_in_sec) - (int(seconds) * microsec_in_sec)

        date = datetime(
            year, month, day,
            int(hours), int(minutes), int(seconds), int(microsec),
            tzinfo=utc)

        # ``toordinal`` returns the number of whole days since December 31,
        # 1 BC and discards the time of day, so we add 1721424.5 to rescale
        # the result to January 1, 4713 BC at 12:00 (i.e., to JD) and add
        # back the fractional day
        return date.toordinal() + 1721424.5 + fractional_days

    raise NotImplementedError(f'Cannot convert format: {format}')
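
# Illustrative usage sketch (not part of the original module). The date
# values below are arbitrary:
#
#   convert_to_jd(500, 'snpy')       # 500 + 53000 + 2400000.5 = 2453500.5
#   convert_to_jd(53500, 'mjd')      # 53500 + 2400000.5 = 2453500.5
#   convert_to_jd(20050609.0, 'ut')  # 2453530.5 (2005-06-09 00:00 UT)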

def download_file(
        url: str,
        destination: Union[str, Path, IO] = None,
        force: bool = False,
        timeout: float = 15,
        verbose: bool = True):
    """Download content from a url to a file

    If ``destination`` is a path but already exists, skip the download
    unless ``force`` is also ``True``.

    Args:
        url: URL of the file to download
        destination: Path or file object to download to
        force: Re-Download locally available data (Default: False)
        timeout: Seconds before raising timeout error (Default: 15)
        verbose: Print status to stdout
    """

    destination_is_path = isinstance(destination, (str, Path))
    if destination_is_path:
        path = Path(destination)
        if (not force) and path.exists():
            return

        path.parent.mkdir(exist_ok=True, parents=True)
        destination = path.open('wb')

    if verbose:
        tqdm.write(f'Fetching {url}', file=sys.stdout)

        response = requests.get(url, stream=True, timeout=timeout)
        total = int(response.headers.get('content-length', 0))
        chunk_size = 1024

        with tqdm(total=total, unit='B', unit_scale=True,
                  unit_divisor=chunk_size, file=sys.stdout) as pbar:
            for data in response.iter_content(chunk_size=chunk_size):
                pbar.update(destination.write(data))

    else:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        destination.write(response.content)

    if destination_is_path:
        destination.close()
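
# Illustrative usage sketch (not part of the original module); the URL and
# output path are hypothetical:
#
#   download_file(
#       'https://example.com/table1.dat',
#       destination='data/table1.dat')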

def download_tar(
        url: str,
        out_dir: str,
        mode: str = 'r:gz',
        force: bool = False,
        timeout: float = 15,
        skip_exists: str = None
):
    """Download and unzip a .tar.gz file to a given output directory

    Args:
        url: URL of the file to download
        out_dir: The directory to unzip file contents to
        mode: Compression mode (Default: r:gz)
        force: Re-Download locally available data (Default: False)
        timeout: Seconds before raising timeout error (Default: 15)
        skip_exists: Optionally skip the download if given path exists
    """

    out_dir = Path(out_dir)

    # Skip the download if the given path already exists
    if skip_exists and Path(skip_exists).exists() and not force:
        return

    # Download data to file and decompress
    with NamedTemporaryFile() as temp_file:
        download_file(url, destination=temp_file, timeout=timeout)

        # Writing to the file moves us to the end of the file
        # We move back to the beginning so we can decompress the data
        temp_file.seek(0)

        out_dir.mkdir(parents=True, exist_ok=True)
        with tarfile.open(fileobj=temp_file, mode=mode) as data_archive:
            for ffile in data_archive:
                try:
                    data_archive.extract(ffile, path=out_dir)

                except IOError:
                    # If output path already exists, delete it and try again
                    (out_dir / ffile.name).unlink()
                    data_archive.extract(ffile, path=out_dir)
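
# Illustrative usage sketch (not part of the original module); the URL and
# paths are hypothetical. The download is skipped if ``skip_exists`` points
# at an existing path:
#
#   download_tar(
#       'https://example.com/photometry.tar.gz',
#       out_dir='data/csp/dr1',
#       skip_exists='data/csp/dr1/photometry')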

def require_data_path(*data_dirs: Path):
    """Raise NoDownloadedData exception if given paths don't exist

    Args:
        *data_dirs: Path objects to check exist
    """

    for data_dir in data_dirs:
        if not data_dir.exists():
            raise NoDownloadedData()

def read_vizier_table_descriptions(readme_path: Union[Path, str]):
    """Returns the table descriptions from a vizier readme file

    Args:
        readme_path: Path of the file to read

    Returns:
        A dictionary {<Table number (int)>: <Table description (str)>}
    """

    table_descriptions = dict()
    with open(readme_path) as ofile:

        # Skip lines before table summary
        line = next(ofile)
        while line.strip() != 'File Summary:':
            line = next(ofile)

        # Skip table header
        for _ in range(5):
            line = next(ofile)

        # Iterate until end of table marker
        while not line.startswith('---'):
            line_list = line.split()
            table_num = line_list[0].lstrip('table').rstrip('.dat')
            if table_num.isdigit():
                table_num = int(table_num)

            table_desc = ' '.join(line_list[3:])
            line = next(ofile)

            # Keep building description for multiline descriptions
            while line.startswith(' '):
                table_desc += ' ' + line.strip()
                line = next(ofile)

            table_descriptions[table_num] = table_desc

    return table_descriptions
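
# Illustrative usage sketch (not part of the original module); the ReadMe
# path is hypothetical:
#
#   descriptions = read_vizier_table_descriptions('data/csp/dr1/ReadMe')
#   # e.g. {1: '...', 2: '...'} keyed by table number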

def register_filter_file(file_path: str, filter_name: str, force: bool = False):
    """Registers filter profiles with sncosmo if not already registered

    Assumes the file at ``file_path`` is a two column, white space delimited
    ascii table.

    Args:
        file_path: Path of ascii table with wavelength (Ang) and transmission
        filter_name: The name of the registered filter.
        force: Whether to re-register a band if already registered
    """

    # Get set of registered builtin and custom band passes
    available_bands = set(
        k[0] for k in sncosmo.bandpasses._BANDPASSES._loaders)

    available_bands.update(
        k[0] for k in sncosmo.bandpasses._BANDPASSES._instances)

    # Register the new bandpass
    if filter_name not in available_bands:
        wave, trans = np.genfromtxt(file_path).T
        is_good_data = ~np.isnan(wave) & ~np.isnan(trans)
        band = sncosmo.Bandpass(wave[is_good_data], trans[is_good_data])
        band.name = filter_name
        sncosmo.register(band, force=force)
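
# Illustrative usage sketch (not part of the original module); the file path
# and band name are hypothetical:
#
#   register_filter_file('filters/my_survey_B.dat', 'my_survey_B')
#   sncosmo.get_bandpass('my_survey_B')  # the band is now resolvable by name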