#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
"""The ``base_classes`` module defines parent classes used by the data access
API to define basic data handling and to enforce a consistent user interface.
For an example on how to use these classes to create custom data access module
for a new survey / data release, see the :ref:`CustomClasses` section of the
docs.
"""
import functools
import shutil
import warnings
from typing import List
from typing import Union
import numpy as np
from astropy.io import ascii
from astropy.table import Table
from . import utils
from .exceptions import InvalidObjId, InvalidTableId
# Define short hand type for Ids of Vizier Tables
VizierTableId = Union[int, str]
def ignore_warnings_wrapper(func: callable) -> callable:
"""Ignores warnings issued by the wrapped function call"""
@functools.wraps(func)
def inner(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter('ignore')
return func(*args, **kwargs)
return inner
[docs]class DefaultParser:
"""Prebuilt data parsing tools for Vizier tables and photometric filters
For more information see the :ref:`CustomClasses` section of the docs.
"""
[docs] def _get_available_tables(self) -> List[VizierTableId]:
"""Default backend functionality of ``get_available_tables`` function"""
# Find available tables - assume standard Vizier naming scheme
# This includes assuming lowercase file names
table_nums = []
for f in self._table_dir.rglob('table*.dat'):
table_number = f.stem.lstrip('table')
try:
table_number = int(table_number)
except ValueError:
pass
table_nums.append(table_number)
return sorted(table_nums, key=str)
[docs] def _load_table(self, table_id: VizierTableId) -> Table:
"""Default backend functionality of ``load_table`` function"""
readme_path = self._table_dir / 'ReadMe'
table_path = self._table_dir / f'table{table_id}.dat'
# Read data from file and add meta data from the readme
data = ascii.read(str(table_path), format='cds', readme=str(readme_path))
description = utils.read_vizier_table_descriptions(readme_path)[table_id]
data.meta['description'] = description
return data
[docs] def _register_filters(self, force: bool = False):
"""Default backend functionality of ``register_filters`` function"""
bandpass_data = zip(self._filter_file_names, self.band_names)
for _file_name, _band_name in bandpass_data:
filter_path = self._filter_dir / _file_name
utils.register_filter_file(filter_path, _band_name, force=force)
[docs]class SpectroscopicRelease:
"""Generic representation of a spectroscopic data release
This class is a template designed to enforce a consistent user interface
and requires child classes to fill in incomplete functionality.
"""
# General metadata
publications = tuple()
ads_url = None
survey_name = None
survey_abbrev = None
release = None
survey_url = None
data_type = 'spectroscopic'
def __init__(self, survey_abbrev: str = None, release: str = None):
"""Represent Vizier data downloaded on the local machine
Args:
survey_abbrev: Abbreviation of the survey to load data for (e.g., CSP)
release: Name of the data release from the survey (e.g., DR1)
"""
err_msg = '``{}`` must either be passed at initialization or set as attribute'
if survey_abbrev is None and not hasattr(self, 'survey_abbrev'):
raise ValueError(err_msg.format('survey_abbrev'))
if release is None and not hasattr(self, 'release'):
raise ValueError(err_msg.format('release'))
self.survey_abbrev = survey_abbrev if survey_abbrev else self.survey_abbrev
self.release = release if release else self.release
self._data_dir = utils.find_data_dir(self.survey_abbrev, self.release)
self._table_dir = self._data_dir / 'tables'
[docs] def get_available_tables(self) -> List[VizierTableId]:
"""Get Ids for available vizier tables published by this data release"""
# Raise error if data is not downloaded
utils.require_data_path(self._data_dir)
return self._get_available_tables()
[docs] @utils.lru_copy_cache(maxsize=None)
@ignore_warnings_wrapper
def load_table(self, table_id: VizierTableId) -> Table:
"""Return a Vizier table published by this data release
Args:
table_id: The published table number or table name
"""
# Raise error if data is not downloaded
if table_id not in self.get_available_tables():
raise InvalidTableId(f'Table {table_id} is not available.')
return self._load_table(table_id)
[docs] def get_available_ids(self) -> List[str]:
"""Return a list of target object IDs for the current survey
Returns:
A list of object IDs as strings
"""
utils.require_data_path(self._data_dir)
return self._get_available_ids()
[docs] @ignore_warnings_wrapper
def get_data_for_id(self, obj_id: str, format_table: bool = True) -> Table:
"""Returns data for a given object ID
See ``get_available_ids()`` for a list of available ID values.
Args:
obj_id: The ID of the desired object
format_table: Format data into the ``sndata`` standard format
Returns:
An astropy table of data for the given ID
"""
if obj_id not in self.get_available_ids():
raise InvalidObjId(f'Object Id not available: {obj_id}')
return self._get_data_for_id(obj_id, format_table)
[docs] def iter_data(
self,
verbose: bool = False,
format_table: bool = True,
filter_func: bool = None) -> Table:
"""Iterate through all available targets and yield data tables
An optional progress bar can be formatted by passing a dictionary of
``tqdm`` arguments. Outputs can be optionally filtered by passing a
function ``filter_func`` that accepts a data table and returns a
boolean.
Args:
verbose: Optionally display progress bar while iterating
format_table: Format data for ``SNCosmo`` (Default: True)
filter_func: An optional function to filter outputs by
Yields:
Astropy tables
"""
# Default to returning only non-empty tables
if filter_func is None:
filter_func = lambda x: x
iterable = utils.build_pbar(self.get_available_ids(), verbose)
for obj_id in iterable:
data_table = self.get_data_for_id(
obj_id, format_table=format_table)
if filter_func(data_table):
yield data_table
[docs] def delete_module_data(self):
"""Delete any data for the current survey / data release"""
try:
shutil.rmtree(self._data_dir)
except FileNotFoundError:
pass
[docs] def download_module_data(self, force: bool = False, timeout: float = 15):
"""Download data for the current survey / data release
Args:
force: Re-Download locally available data
timeout: Seconds before timeout for individual files/archives
"""
if not hasattr(self, '_download_module_data'):
raise RuntimeError(
'This data set does not support downloading remote data')
self._download_module_data(force, timeout)
def __repr__(self):
# Using self.__class__ ensures correct name appears for child classes
class_name = self.__class__.__name__
return f'<{class_name} ({self.survey_abbrev} {self.release})>'
# noinspection PyUnresolvedReferences
[docs]class PhotometricRelease(SpectroscopicRelease):
"""Generic representation of a photometric data release
This class is a template designed to enforce a consistent user interface
and requires child classes to fill in incomplete functionality.
"""
data_type = 'photometric'
# Photometric metadata
@property
def band_names(self) -> tuple:
raise NotImplementedError('Band passes are not defined for this survey')
@property
def zero_point(self) -> tuple:
raise NotImplementedError('Zero points are not defined for this survey')
[docs] @classmethod
def get_zp_for_band(cls, band: str) -> str:
"""Get the zeropoint for a given band name
Args:
band: The name of the bandpass
"""
sorter = np.argsort(cls.band_names)
indices = np.searchsorted(cls.band_names, band, sorter=sorter)
return np.array(cls.zero_point)[sorter[indices]]
[docs] def register_filters(self, force: bool = False):
"""Register filters for this survey / data release with SNCosmo
Args:
force: Re-register a band if already registered
"""
utils.require_data_path(self._data_dir)
self._register_filters(force)