Created May 6, 2022 11:09
Save taldcroft/44a88f079f4afe85ae2a14c75ac2795d to your computer and use it in GitHub Desktop.
Table subclass with parameters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from astropy.table import Table, Column | |
from astropy.table.ndarray_mixin import NdarrayMixin, NdarrayMixinInfo | |
from astropy.io.registry import UnifiedReadWriteMethod | |
from astropy.table.connect import TableRead, TableWrite | |
from astropy.table.info import serialize_method_as | |
def fmt_func(val):
    """Format one parameter record as ``"<par> (<pmn>, <pmx>)"``.

    ``val`` must support indexing by the field names ``"par"``, ``"pmn"`` and
    ``"pmx"`` (for example a single element of a column whose dtype is
    ``PAR_DTYPE``, or a plain dict).  Each field is rendered with two
    decimal places.
    """
    template = '{v[par]:.2f} ({v[pmn]:.2f}, {v[pmx]:.2f})'
    return template.format(v=val)
# Structured dtype for one parameter record: fields 'par', 'pmn', 'pmx',
# each float64 ('f8').
PAR_DTYPE = np.dtype([(field, 'f8') for field in ('par', 'pmn', 'pmx')])
class ParTableRead(TableRead):
    """Reader (used for ``ParTable.read``) that re-attaches ``fmt_func``.

    After the standard ``TableRead`` machinery has produced the table, every
    column whose dtype equals ``PAR_DTYPE`` gets ``fmt_func`` installed as
    its ``info.format``.
    """

    def __call__(self, *args, **kwargs):
        out = super().__call__(*args, **kwargs)
        for col in out.columns.values():
            # Astropy mixin columns are not required to expose ``dtype``;
            # look it up defensively so such columns are skipped instead of
            # raising AttributeError.
            col_dtype = getattr(col, 'dtype', None)
            if col_dtype is not None and col_dtype == PAR_DTYPE:
                col.info.format = fmt_func
        return out
class ParTableWrite(TableWrite):
    """Writer that hides ``fmt_func`` formats for the duration of a write.

    Columns whose ``info.format`` is ``fmt_func`` have it reset to ``None``,
    the write is delegated to ``TableWrite``, and ``fmt_func`` is restored
    afterwards, even if the write raises.  Presumably this is because a
    Python callable format cannot be serialized by the output format
    (e.g. ECSV); TODO confirm.

    NOTE(review): this class is not wired into ``ParTable``, which overrides
    ``write()`` directly instead.  An earlier version of this method failed.
    Two defects are fixed here:
      * the restore loop indexed the writer object (``self[...]``) instead of
        the table, which raised inside ``finally``;
      * ``serialize_method`` was accepted but silently dropped.
    Confirm it works end-to-end before enabling it via
    ``UnifiedReadWriteMethod``.
    """

    def __call__(self, *args, serialize_method=None, **kwargs):
        instance = self._instance  # the table being written
        par_cols = []
        try:
            for col in instance.columns.values():
                if col.info.format is fmt_func:
                    par_cols.append(col.info.name)
                    col.info.format = None
            return super().__call__(
                *args, serialize_method=serialize_method, **kwargs
            )
        finally:
            # Restore on the table itself, not on this writer object.
            for par_col in par_cols:
                instance[par_col].info.format = fmt_func
class ParTable(Table):
    """``Table`` subclass that pretty-prints parameter columns.

    Any column whose dtype equals ``PAR_DTYPE`` is given ``fmt_func`` as its
    ``info.format``.  This happens both when the column is added to the table
    and when the table is read via ``ParTable.read``.
    """

    read = UnifiedReadWriteMethod(ParTableRead)
    # NOTE: ``write = UnifiedReadWriteMethod(ParTableWrite)`` would mirror
    # ``read`` above; ``write()`` is overridden directly below instead.

    def _convert_data_to_col(self, *args, **kwargs):
        """Convert input to a column, attaching ``fmt_func`` to parameter columns."""
        col = super()._convert_data_to_col(*args, **kwargs)
        # Astropy mixin columns are not required to expose ``dtype``; look it
        # up defensively so adding such a column does not raise AttributeError.
        col_dtype = getattr(col, 'dtype', None)
        if col_dtype is not None and col_dtype == PAR_DTYPE:
            col.info.format = fmt_func
        return col

    def write(self, *args, **kwargs):
        """Write the table with ``fmt_func`` formats removed from the output.

        A plain ``Table`` copy (``Table(self)``, which copies by default) is
        written, so the ``fmt_func`` formats on this table are left intact.
        All arguments are forwarded unchanged to ``Table.write``.

        Defining ``write`` as a plain method shadows the unified I/O
        descriptor, so helpers such as ``ParTable.write.help()`` are not
        available on this class.
        """
        plain = Table(self)
        for col in plain.columns.values():
            if col.info.format is fmt_func:
                col.info.format = None
        return plain.write(*args, **kwargs)
# Demo: build a two-row parameter table, write it to ECSV, then read it back.
# NOTE(review): this executes (and writes ``pars.ecsv``) at import time;
# consider an ``if __name__ == "__main__":`` guard if the module is imported.
name = ['par1', 'par2']
a = np.array(
    [
        (np.pi, 2, 4.5),
        (np.pi / 2, 1, 3.1),
    ],
    dtype=PAR_DTYPE,
)
t = ParTable(data=[name, a], names=['name', 'par'])
t.write('pars.ecsv', overwrite=True)
t2 = ParTable.read('pars.ecsv')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.