Source code for sndata.base_classes

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

"""The ``base_classes`` module defines parent classes used by the data access
API to define basic data handling and to enforce a consistent user interface.
For an example of how to use these classes to create a custom data access
module for a new survey / data release, see the :ref:`CustomClasses` section
of the docs.
"""

import functools
import shutil
import warnings
from typing import List
from typing import Union

import numpy as np
from astropy.io import ascii
from astropy.table import Table

from . import utils
from .exceptions import InvalidObjId, InvalidTableId

# Define short hand type for Ids of Vizier Tables. Ids are either integers
# (purely numeric table names) or strings (any other table name).
VizierTableId = Union[int, str]

def ignore_warnings_wrapper(func: callable) -> callable:
    """Ignores warnings issued by the wrapped function call

    Args:
        func: The function to wrap

    Returns:
        A wrapper around ``func`` that suppresses every warning issued while
        ``func`` runs and passes its return value through unchanged
    """

    @functools.wraps(func)  # Keep the name and docstring of ``func``
    def inner(*args, **kwargs):
        with warnings.catch_warnings():
            # ``catch_warnings`` only saves and restores the filter state.
            # An explicit ``ignore`` filter is required to silence warnings.
            warnings.simplefilter('ignore')
            return func(*args, **kwargs)

    return inner

class DefaultParser:
    """Prebuilt data parsing tools for Vizier tables and photometric filters

    The methods of this class read the attributes ``_table_dir``,
    ``_filter_dir``, ``_filter_file_names`` and ``band_names`` from the
    instance. None of them are defined here.

    For more information see the :ref:`CustomClasses` section of the docs.
    """

    def _get_available_tables(self) -> List[VizierTableId]:
        """Default backend functionality of ``get_available_tables`` function

        Returns:
            Table Ids found under ``self._table_dir``, ordered by their
            string representation. Purely numeric Ids are returned as
            integers and all other Ids as strings.
        """

        # Find available tables - assume standard Vizier naming scheme
        # This includes assuming lowercase file names
        prefix = 'table'
        table_ids = []
        for path in self._table_dir.rglob('table*.dat'):
            stem = path.stem

            # Slice the prefix off. ``str.lstrip`` removes a *set* of
            # characters and would also eat the start of Ids such as ``b2``.
            table_id = stem[len(prefix):] if stem.startswith(prefix) else stem

            try:
                table_id = int(table_id)

            except ValueError:
                pass  # Non-numeric Ids are kept as strings

            table_ids.append(table_id)

        # Integers and strings are not mutually orderable, so sort on ``str``
        return sorted(table_ids, key=str)

    def _load_table(self, table_id: VizierTableId) -> Table:
        """Default backend functionality of ``load_table`` function

        Args:
            table_id: The published table number or table name

        Returns:
            The table read from ``table<table_id>.dat`` with the table
            description stored in ``meta['description']``
        """

        readme_path = self._table_dir / 'ReadMe'
        table_path = self._table_dir / f'table{table_id}.dat'

        # Read data from file and add meta data from the readme
        data = ascii.read(str(table_path), format='cds', readme=str(readme_path))
        description = utils.read_vizier_table_descriptions(readme_path)[table_id]
        data.meta['description'] = description
        return data

    def _register_filters(self, force: bool = False):
        """Default backend functionality of ``register_filters`` function

        Args:
            force: Passed through to ``utils.register_filter_file``
        """

        # File names and band names are paired by position
        bandpass_data = zip(self._filter_file_names, self.band_names)
        for _file_name, _band_name in bandpass_data:
            filter_path = self._filter_dir / _file_name
            utils.register_filter_file(filter_path, _band_name, force=force)
class SpectroscopicRelease:
    """Generic representation of a spectroscopic data release

    This class is a template designed to enforce a consistent user interface
    and requires child classes to fill in incomplete functionality.

    The public methods delegate to the private hooks
    ``_get_available_tables``, ``_load_table``, ``_get_available_ids``,
    ``_get_data_for_id`` and (optionally) ``_download_module_data``. None of
    these hooks are defined by this class.
    """

    # General metadata
    publications = tuple()
    ads_url = None
    survey_name = None
    survey_abbrev = None
    release = None
    survey_url = None
    data_type = 'spectroscopic'

    def __init__(self, survey_abbrev: str = None, release: str = None):
        """Represent Vizier data downloaded on the local machine

        Args:
            survey_abbrev: Abbreviation of the survey to load data for (e.g., CSP)
            release: Name of the data release from the survey (e.g., DR1)

        Raises:
            ValueError: If ``survey_abbrev`` or ``release`` is neither passed
                as an argument nor set to a non-``None`` attribute value
        """

        err_msg = '``{}`` must either be passed at initialization or set as attribute'

        # The class level defaults are ``None``, so ``hasattr`` is always
        # true here. Check the attribute *value* instead.
        if survey_abbrev is None and getattr(self, 'survey_abbrev', None) is None:
            raise ValueError(err_msg.format('survey_abbrev'))

        if release is None and getattr(self, 'release', None) is None:
            raise ValueError(err_msg.format('release'))

        # Arguments take priority over values defined as attributes
        self.survey_abbrev = survey_abbrev if survey_abbrev else self.survey_abbrev
        self.release = release if release else self.release

        self._data_dir = utils.find_data_dir(self.survey_abbrev, self.release)
        self._table_dir = self._data_dir / 'tables'

    def get_available_tables(self) -> List[VizierTableId]:
        """Get Ids for available vizier tables published by this data release"""

        # Raise error if data is not downloaded
        utils.require_data_path(self._data_dir)
        return self._get_available_tables()

    @utils.lru_copy_cache(maxsize=None)
    @ignore_warnings_wrapper
    def load_table(self, table_id: VizierTableId) -> Table:
        """Return a Vizier table published by this data release

        Args:
            table_id: The published table number or table name

        Raises:
            InvalidTableId: If ``table_id`` is not listed by
                ``get_available_tables``
        """

        # Raise error if the requested table is not available
        if table_id not in self.get_available_tables():
            raise InvalidTableId(f'Table {table_id} is not available.')

        return self._load_table(table_id)

    def get_available_ids(self) -> List[str]:
        """Return a list of target object IDs for the current survey

        Returns:
            A list of object IDs as strings
        """

        # Raise error if data is not downloaded
        utils.require_data_path(self._data_dir)
        return self._get_available_ids()

    @ignore_warnings_wrapper
    def get_data_for_id(self, obj_id: str, format_table: bool = True) -> Table:
        """Returns data for a given object ID

        See ``get_available_ids()`` for a list of available ID values.

        Args:
            obj_id: The ID of the desired object
            format_table: Format data into the ``sndata`` standard format

        Returns:
            An astropy table of data for the given ID

        Raises:
            InvalidObjId: If ``obj_id`` is not listed by ``get_available_ids``
        """

        if obj_id not in self.get_available_ids():
            raise InvalidObjId(f'Object Id not available: {obj_id}')

        return self._get_data_for_id(obj_id, format_table)

    def iter_data(
            self,
            verbose: bool = False,
            format_table: bool = True,
            filter_func: callable = None) -> Table:
        """Iterate through all available targets and yield data tables

        An optional progress bar can be formatted by passing a dictionary of
        ``tqdm`` arguments. Outputs can be optionally filtered by passing a
        function ``filter_func`` that accepts a data table and returns a
        boolean.

        Args:
            verbose: Optionally display progress bar while iterating
            format_table: Format data for ``SNCosmo`` (Default: True)
            filter_func: An optional function to filter outputs by

        Yields:
            Astropy tables
        """

        # Default to returning only non-empty tables, i.e. filter on the
        # truthiness of each table
        if filter_func is None:
            filter_func = bool

        iterable = utils.build_pbar(self.get_available_ids(), verbose)
        for obj_id in iterable:
            data_table = self.get_data_for_id(
                obj_id, format_table=format_table)

            if filter_func(data_table):
                yield data_table

    def delete_module_data(self):
        """Delete any data for the current survey / data release"""

        try:
            shutil.rmtree(self._data_dir)

        except FileNotFoundError:
            pass  # Nothing to delete - the data was never downloaded

    def download_module_data(self, force: bool = False, timeout: float = 15):
        """Download data for the current survey / data release

        Args:
            force: Re-Download locally available data
            timeout: Seconds before timeout for individual files/archives

        Raises:
            RuntimeError: If no ``_download_module_data`` method is defined
        """

        # Downloading is an optional capability provided by child classes
        if not hasattr(self, '_download_module_data'):
            raise RuntimeError(
                'This data set does not support downloading remote data')

        self._download_module_data(force, timeout)

    def __repr__(self):
        # Using self.__class__ ensures correct name appears for child classes
        class_name = self.__class__.__name__
        return f'<{class_name} ({self.survey_abbrev} {self.release})>'
# noinspection PyUnresolvedReferences
class PhotometricRelease(SpectroscopicRelease):
    """Generic representation of a photometric data release

    This class is a template designed to enforce a consistent user interface
    and requires child classes to fill in incomplete functionality.
    """

    data_type = 'photometric'

    # Photometric metadata
    @property
    def band_names(self) -> tuple:
        raise NotImplementedError('Band passes are not defined for this survey')

    @property
    def zero_point(self) -> tuple:
        raise NotImplementedError('Zero points are not defined for this survey')

    @classmethod
    def get_zp_for_band(cls, band: str) -> float:
        """Get the zeropoint for a given band name

        ``band_names`` and ``zero_point`` are read from the class and are
        paired by position.

        Args:
            band: The name of the bandpass

        Returns:
            The element of ``zero_point`` at the position of ``band`` in
            ``band_names``

        Raises:
            ValueError: If ``band`` is not one of ``band_names``
        """

        band_names = np.asarray(cls.band_names)
        if band_names.size == 0:
            raise ValueError(f'Unknown band name: {band}')

        sorter = np.argsort(band_names)
        indices = np.searchsorted(band_names, band, sorter=sorter)

        # ``searchsorted`` returns insertion points, not matches. Clip the
        # points that fall past the end so they can be checked below.
        indices = np.minimum(indices, band_names.size - 1)

        # An unknown name would otherwise return the zero point of a
        # neighbouring band without any error
        if not np.all(band_names[sorter[indices]] == band):
            raise ValueError(f'Unknown band name: {band}')

        return np.array(cls.zero_point)[sorter[indices]]

    def register_filters(self, force: bool = False):
        """Register filters for this survey / data release with SNCosmo

        Args:
            force: Re-register a band if already registered
        """

        # Raise error if data is not downloaded
        utils.require_data_path(self._data_dir)
        self._register_filters(force)