| import os |
| import gc |
| import time |
| import shutil |
| import logging |
| from pathlib import Path |
| from huggingface_hub import WebhooksServer, WebhookPayload |
| from datasets import Dataset, load_dataset, disable_caching |
| from fastapi import BackgroundTasks, Response, status |
|
|
|
|
def clear_huggingface_cache(cache_dir=None):
    """Delete the local Hugging Face ``datasets`` cache directory, if present.

    Args:
        cache_dir: Directory to delete (``str`` or ``Path``). When omitted, the
            location is resolved from the environment: ``$HF_DATASETS_CACHE``,
            then ``$HF_HOME/datasets``, then
            ``$XDG_CACHE_HOME/huggingface/datasets``, falling back to
            ``~/.cache/huggingface/datasets``. Empty environment values are
            treated as unset.

    Returns:
        ``True`` if a directory was removed, ``False`` if no directory existed
        at the resolved path (a regular file at that path is left untouched).
    """
    # Same logger name as the module-level logger, looked up here so the
    # function does not depend on definition order within the module.
    log = logging.getLogger("basic_logger")

    if cache_dir is None:
        # Honour the environment overrides the `datasets` library reads; a
        # hard-coded ~/.cache path silently misses the real cache whenever one
        # of them is set. Truthiness checks (not `is None`) ensure an empty
        # variable can never resolve to "" -> the current working directory.
        datasets_cache = os.getenv("HF_DATASETS_CACHE")
        if datasets_cache:
            cache_dir = datasets_cache
        else:
            hf_home = os.getenv("HF_HOME")
            if not hf_home:
                xdg_cache = os.getenv("XDG_CACHE_HOME") or str(Path.home() / ".cache")
                hf_home = str(Path(xdg_cache) / "huggingface")
            cache_dir = Path(hf_home) / "datasets"

    cache_path = Path(cache_dir).expanduser()

    # is_dir() is False both for a missing path and for a non-directory.
    if not cache_path.is_dir():
        log.info("Cache directory does not exist.")
        return False

    shutil.rmtree(cache_path)
    log.info("Removed cache directory: %s", cache_path)
    return True
|
|
|
|
| |
# Globally turn off the `datasets` library's cache reuse for this process.
# NOTE(review): per the `datasets` docs this governs cached transform results
# (map/filter, ...); confirm whether it also prevents reuse of the output of
# `Dataset.from_generator` in process_and_push_data.
disable_caching()
|
|
| |
# Application logger ("basic_logger"), emitting records at INFO and above.
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Attach the console handler only when the logger has none yet. Loggers are
# process-wide singletons, so re-executing this module (importlib.reload, a
# dev-server reloader, importing it under two module names) would otherwise
# stack duplicate handlers and print every record several times.
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
|
|
| |
# Hub dataset read (streamed) by get_data.
DS_NAME: str = "amaye15/object-segmentation"
# Hub repository the processed dataset is pushed to.
TARGET_REPO: str = "amaye15/object-segmentation-processed"
# Local data directory, relative to the process's working directory.
DATA_DIR: Path = Path("data")
# Secret handed to WebhooksServer; None when HF_WEBHOOK_SECRET is not set.
WEBHOOK_SECRET = os.environ.get("HF_WEBHOOK_SECRET")
|
|
|
|
def get_data():
    """Yield, one at a time, the rows of the ``train`` split of ``DS_NAME``.

    The dataset is opened with ``streaming=True`` so rows are pulled lazily
    instead of being loaded into memory all at once, which keeps memory use
    bounded for large datasets.
    """
    train_split = load_dataset(DS_NAME, streaming=True)["train"]
    for example in train_split:
        yield example
|
|
|
|
def process_and_push_data():
    """Rebuild the processed dataset from ``DS_NAME`` and push it to the hub.

    Deletes ``DATA_DIR`` if it exists and recreates it empty, materialises
    every row yielded by ``get_data`` into a ``Dataset`` whose cache files
    live in ``DATA_DIR``, and pushes it to ``TARGET_REPO`` in shards of at
    most 1GB.
    """
    # Start from an empty, private cache directory so each run regenerates the
    # data instead of possibly reusing a cached result from an earlier run.
    # NOTE(review): from_generator's cache key presumably does not change when
    # the upstream dataset does, which is why a fresh directory is used here.
    if DATA_DIR.is_dir():
        shutil.rmtree(DATA_DIR)
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    ds_processed = Dataset.from_generator(get_data, cache_dir=str(DATA_DIR))
    ds_processed.push_to_hub(TARGET_REPO, max_shard_size="1GB")

    # Drop our reference to the dataset (whose files sit in DATA_DIR) and
    # collect promptly, so memory is released and the next run can wipe the
    # directory cleanly.
    del ds_processed
    gc.collect()

    logger.info("Data processed and pushed to the hub.")
|
|
|
|
| |
# Server hosting the webhook endpoints registered below via @app.add_webhook.
# WEBHOOK_SECRET is None when HF_WEBHOOK_SECRET is unset.
# NOTE(review): confirm how WebhooksServer authenticates requests when no
# secret is passed — an unauthenticated endpoint would let anyone trigger a
# full reprocess-and-push.
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)
|
|
|
|
@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.

    Queues, in order, a 15-second delay, a wipe of the local ``datasets``
    cache and the reprocessing job, then immediately answers 202 Accepted.
    """
    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    # The delay and the cache wipe are blocking calls (time.sleep,
    # shutil.rmtree). Running them directly inside this `async` handler would
    # stall the event loop, and with it every other request, for 15+ seconds.
    # Background tasks run in the order they are added, after the response is
    # sent, and plain `def` callables run in a worker thread.
    # The 15 s pause presumably lets the Hub finish processing the pushed
    # commit before the source dataset is re-read.
    task_queue.add_task(time.sleep, 15)
    task_queue.add_task(clear_huggingface_cache)
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)
|
|
|
|
def _process_webhook():
    """Background job: rebuild the processed dataset and push it to the hub.

    Scheduled by ``handle_repository_changes`` through the request's
    ``BackgroundTasks`` queue. The source dataset is streamed inside
    ``process_and_push_data``, so there is no separate loading step here.
    """
    logger.info("Processing and updating dataset...")
    try:
        process_and_push_data()
    except Exception:
        # Outermost frame of a background job: the webhook response has
        # already been sent, so no caller can handle the error. Record the
        # full traceback in the application's logger instead of losing it.
        logger.exception("Processing and updating dataset failed!")
        return
    logger.info("Processing and updating dataset completed!")
|
|
|
|
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0) at port 7860.
    app.launch(
        server_port=7860,
        server_name="0.0.0.0",
        show_error=True,
    )
|
|