Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions tests/dummy_tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pyvo as vo
import time
import logging
import json
from hashlib import md5
from pathlib import Path
from astropy.table import Table

Expand Down Expand Up @@ -30,9 +32,18 @@ class DummyAsyncTAPJob:
Hacky drop-in replacement for AsyncTAPJob
"""

def __init__(self, url, *, session=None, delete=True, fail_fetch=False):
def __init__(
self,
url,
*,
session=None,
delete=True,
fail_fetch=False,
final_phase="COMPLETED",
):
self.url = url
self.fail_fetch = fail_fetch
self.final_phase = final_phase

def _update(self, wait_for_statechange=False, timeout=10.0):
pass
Expand Down Expand Up @@ -86,7 +97,7 @@ def create(
def phase(self):
created = float(self.url.split("created")[-1])
if time.time() - created > 1:
return "COMPLETED"
return self.final_phase
return "RUNNING"

def fetch_result(self):
Expand Down Expand Up @@ -120,6 +131,7 @@ def __init__(
session=None,
fail_submit=False,
fail_fetch=False,
final_job_phase="COMPLETED",
sync_res: Table | None = None,
):
super(DummyTAPService, self).__init__(
Expand All @@ -129,13 +141,29 @@ def __init__(
self.fail_submit = fail_submit
self.fail_fetch = fail_fetch
self.sync_res = sync_res
self.final_job_phase = final_job_phase
self.tries = {}
self.keys = {}

def submit_job(
self, query, *, language="ADQL", maxrec=None, uploads=None, **keywords
):
if self.fail_submit:
raise vo.dal.exceptions.DALServiceError("failed submit")

key = md5(
(
query
+ json.dumps({k: v.to_pandas().to_dict() for k, v in uploads.items()})
).encode()
).hexdigest()
n_try = self.tries.get(key, 0)
final_phase = (
self.final_job_phase
if isinstance(self.final_job_phase, str)
else self.final_job_phase[n_try]
)

job = DummyAsyncTAPJob.create(
self.baseurl,
query,
Expand All @@ -144,15 +172,26 @@ def submit_job(
uploads=uploads,
session=self._session,
chunksize=self.chunksize,
final_phase=final_phase,
**keywords,
)
logger.debug(job.url)
self.tries[key] = n_try + 1
self.keys[job.url] = key
assert job.phase
return job

def get_job_from_url(self, url):
if isinstance(self.final_job_phase, str):
final_phase = self.final_job_phase
else:
n_try = self.tries[self.keys[url]]
final_phase = self.final_job_phase[n_try - 1]
return DummyAsyncTAPJob(
url=url, session=self._session, fail_fetch=self.fail_fetch
url=url,
session=self._session,
fail_fetch=self.fail_fetch,
final_phase=final_phase,
)

def run_sync(
Expand Down
20 changes: 20 additions & 0 deletions tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,23 @@ def test_downloader_creates_files(download_cfg):
else:
m = [True] * len(reference)
assert sum(np.array(reference[col] - produced[col])[m]) == 0


@pytest.mark.parametrize("resubmit", [False, True])
def test_resubmit(download_cfg, resubmit):
dl = download_cfg.build_downloader(resubmit_failed=resubmit)

# mimick behavior when job results disappear from server
dl.service = DummyTAPService(
baseurl="",
chunksize=download_cfg.chunk_size,
final_job_phase=[None, "COMPLETED"],
)

# make sure that downloader fails / succeeds as expected when not retrying
dl.run()
for t in dl.iter_tasks():
if resubmit:
assert dl.backend.is_done(t)
else:
assert not dl.backend.is_done(t)
2 changes: 2 additions & 0 deletions timewise/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def meta_exists(self, task: TaskID) -> bool: ...
def save_meta(self, task: TaskID, meta: dict[str, Any]) -> None: ...
@abc.abstractmethod
def load_meta(self, task: TaskID) -> dict[str, Any] | None: ...
@abc.abstractmethod
def drop_meta(self, task: TaskID) -> None: ...

# --- Markers ---
@abc.abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions timewise/backend/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def load_meta(self, task: TaskID) -> dict[str, Any] | None:
def meta_exists(self, task: TaskID) -> bool:
return self._meta_path(task).exists()

def drop_meta(self, task: TaskID) -> None:
self._meta_path(task).unlink()

# ----------------------------
# Markers
# ----------------------------
Expand Down
10 changes: 9 additions & 1 deletion timewise/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,16 @@ def main(
@app.command(help="Download WISE photometry from IRSA")
def download(
config_path: config_path_type,
resubmit_failed: Annotated[
bool,
typer.Option(
help="Re-submit jobs when failed due to connection issues",
),
] = False,
):
TimewiseConfig.from_yaml(config_path).download.build_downloader().run()
TimewiseConfig.from_yaml(config_path).download.build_downloader(
resubmit_failed=resubmit_failed
).run()


# the following commands will only be added if ampel is installed
Expand Down
24 changes: 14 additions & 10 deletions timewise/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class DownloadConfig(BaseModel):
poll_interval: float = 10.0
queries: List[QueryType] = Field(..., description="One or more queries per chunk")
backend: BackendType = Field(..., discriminator="type")
resubmit_failed: bool = False

service_url: str = "https://irsa.ipac.caltech.edu/TAP"

Expand Down Expand Up @@ -57,13 +58,16 @@ def validate_input_csv_columns(self) -> "DownloadConfig":

return self

def build_downloader(self) -> Downloader:
return Downloader(
service_url=self.service_url,
input_csv=self.expanded_input_csv,
chunk_size=self.chunk_size,
backend=self.backend,
queries=self.queries,
max_concurrent_jobs=self.max_concurrent_jobs,
poll_interval=self.poll_interval,
)
def build_downloader(self, **overwrite) -> Downloader:
default = {
"service_url": self.service_url,
"input_csv": self.expanded_input_csv,
"chunk_size": self.chunk_size,
"backend": self.backend,
"queries": self.queries,
"max_concurrent_jobs": self.max_concurrent_jobs,
"poll_interval": self.poll_interval,
"resubmit_failed": self.resubmit_failed,
}
default.update(overwrite)
return Downloader(**default) # type: ignore
29 changes: 26 additions & 3 deletions timewise/io/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from itertools import product
from pathlib import Path
from queue import Empty
from typing import Dict, Iterator, Sequence, cast
from typing import Dict, Iterator

import numpy as np
import pandas as pd
from astropy.table import Table
from pyvo.utils.http import create_session

Expand All @@ -32,6 +31,7 @@ def __init__(
queries: list[QueryType],
max_concurrent_jobs: int,
poll_interval: float,
resubmit_failed: bool,
):
self.backend = backend
self.queries = queries
Expand Down Expand Up @@ -65,6 +65,7 @@ def __init__(
self.service: StableTAPService = StableTAPService(
service_url, session=self.session
)
self.resubmit_failed = resubmit_failed

self.chunker = Chunker(input_csv=input_csv, chunk_size=chunk_size)

Expand Down Expand Up @@ -131,7 +132,6 @@ def submit_tap_job(self, query: QueryType, chunk: Chunk) -> TAPJobMeta:
logger.debug(f"uploading {len(upload)} objects.")
job = self.service.submit_job(adql, uploads={query.upload_name: upload})
job.run()
logger.debug(job.url)

return TAPJobMeta(
url=job.url,
Expand Down Expand Up @@ -192,6 +192,26 @@ def _submission_worker(self):
# ----------------------------
# Polling thread
# ----------------------------

def resubmit(self, resubmit_task: TaskID):
logger.info(f"resubmitting {resubmit_task}")
submit = None
for chunk, q in product(self.chunker, self.queries):
task = self.get_task_id(chunk, q)
if task == resubmit_task:
submit = chunk, q
break
if submit is None:
raise RuntimeError(f"resubmit task {resubmit_task} not found!")

# remove current info, so the job won't be re-submitted over and over again
self.backend.drop_meta(resubmit_task)
with self.job_lock:
self.jobs.pop(resubmit_task)

# put task back in resubmit queue
self.submit_queue.put(submit)

def _polling_worker(self):
logger.debug("starting polling worker")
backend = self.backend
Expand Down Expand Up @@ -223,6 +243,9 @@ def _polling_worker(self):
f"No job found under {meta['url']} for {task}! "
f"Probably took too long before downloading results."
)
if self.resubmit_failed:
self.resubmit(task)
continue

meta["status"] = status
with self.job_lock:
Expand Down
17 changes: 10 additions & 7 deletions timewise/io/stable_tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

import requests

from timewise.util.backoff import backoff_hndlr


logger = logging.getLogger(__name__)


Expand All @@ -26,7 +23,6 @@ def __init__(self, url, *, session=None, delete=True):
backoff.expo,
requests.exceptions.HTTPError,
max_tries=5,
on_backoff=backoff_hndlr,
)
def create(
cls,
Expand Down Expand Up @@ -92,7 +88,6 @@ def create(
backoff.expo,
(vo.dal.DALServiceError, AttributeError),
max_tries=50,
on_backoff=backoff_hndlr,
)
def phase(self):
return super(StableAsyncTAPJob, self).phase
Expand All @@ -101,7 +96,6 @@ def phase(self):
backoff.expo,
vo.dal.DALServiceError,
max_tries=50,
on_backoff=backoff_hndlr,
)
def _update(self, *args, **kwargs):
return super(StableAsyncTAPJob, self)._update(*args, **kwargs)
Expand All @@ -116,7 +110,6 @@ class StableTAPService(vo.dal.TAPService):
backoff.expo,
(vo.dal.DALServiceError, AttributeError, AssertionError),
max_tries=5,
on_backoff=backoff_hndlr,
)
def submit_job(
self, query, *, language="ADQL", maxrec=None, uploads=None, **keywords
Expand All @@ -136,3 +129,13 @@ def submit_job(

def get_job_from_url(self, url):
return StableAsyncTAPJob(url, session=self._session)

@backoff.on_exception(
backoff.expo,
(vo.dal.DALServiceError, vo.dal.DALFormatError),
max_tries=5,
)
def run_sync(
self, query, *, language="ADQL", maxrec=None, uploads=None,
**keywords):
return super().run_sync(query, language=language, maxrec=maxrec, uploads=uploads, **keywords)
4 changes: 1 addition & 3 deletions timewise/plot/sdss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import matplotlib.pyplot as plt
import backoff

from ..util.backoff import backoff_hndlr


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -34,7 +32,7 @@ def login_to_sciserver():


@backoff.on_exception(
backoff.expo, requests.RequestException, max_tries=50, on_backoff=backoff_hndlr
backoff.expo, requests.RequestException, max_tries=50
)
def get_cutout(*args, **kwargs):
login_to_sciserver()
Expand Down
1 change: 0 additions & 1 deletion timewise/query/positional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,4 @@ def build(self) -> str:
q = q.strip(" AND \n")
q += "\t)"

logger.debug(f"\n{q}")
return q
12 changes: 0 additions & 12 deletions timewise/util/backoff.py

This file was deleted.

Loading