File size: 1,773 Bytes
0775c37
 
95f4708
 
0775c37
 
 
 
 
 
 
 
 
 
5508593
0775c37
 
 
 
5508593
0775c37
5508593
0775c37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import asyncio
import uuid
import db
from worker import run_job_async

class JobQueue:
    """Job queue persisted through the ``db`` module.

    Jobs are created with :meth:`add_job` and executed one at a time by the
    long-running :meth:`process_jobs` loop, which hands each job's params to
    ``run_job_async``.
    """

    # Params keys holding user credentials; cleared once a job has finished.
    _CREDENTIAL_KEYS = ("hf_token", "civitai_key")

    async def add_job(self, req, token):
        """Persist a new job built from *req* and return its id.

        Args:
            req: Request object whose merge-related attributes are copied into
                the job params (presumably an API request model -- confirm
                against caller).
            token: Token stored under the ``"hf_token"`` params key.

        Returns:
            The new job's id, a random UUID4 rendered as a string.
        """
        job_id = str(uuid.uuid4())
        # NOTE(review): hf_token / civitai_key are handed to db.create_job
        # verbatim; whether the store protects them is not visible here.
        params = {
            "model_a_source": req.model_a_source,
            "model_a_id": req.model_a_id,
            "model_b_source": req.model_b_source,
            "model_b_id": req.model_b_id,
            "method": req.method,
            "merge_type": req.merge_type,
            "linear_alpha": req.linear_alpha,
            "evo_params": req.evo_params,
            "dataset": req.dataset,
            "output_repo_name": req.output_repo_name,
            "hf_token": token,
            "civitai_key": req.civitai_key,
            "franken_layers": req.franken_layers,
        }
        await db.create_job(job_id, params)
        return job_id

    async def process_jobs(self, poll_interval=1.0):
        """Run pending jobs forever, one at a time.

        Each job is marked ``"running"``, executed, then marked exactly once
        as ``"completed"`` (with the worker's result) or ``"failed"`` (with
        ``{"error": str(exc)}``).

        Args:
            poll_interval: Seconds to sleep when no pending job is available.
        """
        while True:
            job = await db.get_next_pending()
            if not job:
                # Back off only when idle; after a job finishes the next one
                # is polled immediately.
                await asyncio.sleep(poll_interval)
                continue
            await self._run_job(job)

    async def _run_job(self, job):
        """Execute one job and record its final status and result once."""
        job_id = job["id"]
        params = job["params"]
        await db.update_job(job_id, "running")
        try:
            result = await run_job_async(job_id, params)
        except Exception as e:
            # Per-job boundary: record the failure and keep the loop alive.
            status, result = "failed", {"error": str(e)}
        else:
            status = "completed"
        finally:
            # Runs even on cancellation (a BaseException), in which case the
            # exception propagates and the final update below is skipped.
            # NOTE(review): this only mutates the in-memory params; no visible
            # db call persists them, so stored credentials are not cleared
            # here -- TODO confirm/persist.
            self._scrub_credentials(params)
        await db.update_job(job_id, status, result)

    @classmethod
    def _scrub_credentials(cls, params):
        """Set every credential key in *params* to ``None`` (in place)."""
        for key in cls._CREDENTIAL_KEYS:
            params[key] = None

    async def get_queue_status(self):
        """Return ``{"queued": n, "running": m}`` from ``db.get_queue_counts``.

        Missing statuses default to 0. NOTE(review): this file never writes a
        ``"queued"`` status itself; that key presumably matches the status
        db.create_job assigns -- confirm.
        """
        counts = await db.get_queue_counts()
        return {
            "queued": counts.get("queued", 0),
            "running": counts.get("running", 0)
        }