File size: 6,054 Bytes
17e971c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
import os
import threading
import queue
import time
import traceback
from typing import List, Dict, Callable, Any, Union, Optional, Tuple

from mod.base.ssh_executor import SSHExecutor
from mod.project.node.dbutil import ServerNodeDB, CommandTask, CommandLog, TaskFlowsDB, TransferTask
from mod.project.node.dbutil import TaskFlowsDB
from mod.project.node.nodeutil import LPanelNode, ServerNode, SSHApi
from mod.project.node.filetransfer.socket_server import StatusServer, StatusClient, register_cleanup

from .command_task import CMDTask
from .file_task import FiletransferTask, NodeFiletransferTask

_SOCKET_FILE_DIR = "/tmp/flow_task"
if not os.path.exists(_SOCKET_FILE_DIR):
    os.mkdir(_SOCKET_FILE_DIR)



class FlowTask:
    """Run the steps of a stored task flow in order and publish its progress.

    The steps are the flow's ``CommandTask`` and ``TransferTask`` rows from
    ``TaskFlowsDB``, sorted by ``step_index``. Progress snapshots produced by
    ``get_status`` are served by a ``StatusServer`` bound to
    ``_SOCKET_FILE_DIR + "/flow_task_<flow_id>"``.
    """

    def __init__(self, flow_id: int, step_idx: int = 0, sub_id: int = 0):
        """Load the flow and its steps.

        Args:
            flow_id: Id of the flow to load from ``TaskFlowsDB``.
            step_idx: Not used by this class; kept for interface compatibility.
            sub_id: Not used by this class; kept for interface compatibility.

        Raises:
            RuntimeError: If the flow does not exist, or if it has no steps.
        """
        self._fdb = TaskFlowsDB()
        self.flow = self._fdb.Flow.get_byid(flow_id)
        if not self.flow:
            raise RuntimeError("任务不存在")

        self.steps: List[Union[CommandTask, TransferTask]] = [
            *self._fdb.CommandTask.query("flow_id = ?", (flow_id,)),
            *self._fdb.TransferTask.query("flow_id = ?", (flow_id,))
        ]
        # Command and transfer steps come from separate tables; interleave
        # them into a single execution order by step_index.
        self.steps.sort(key=lambda x: x.step_index)

        if not self.steps:
            raise RuntimeError("任务内容不存在")
        # 1-based position of the current step, reported by get_status().
        self.now_idx = 1
        # Whether to keep executing later steps after any step reports an error.
        self.run_when_error = bool(self.flow.strategy.get("run_when_error", False))
        # Whether nodes that failed in one step are excluded from later steps.
        self.exclude_when_error = bool(self.flow.strategy.get("exclude_when_error", True))

        self.status_server = StatusServer(self.get_status, (_SOCKET_FILE_DIR + "/flow_task_" + str(flow_id)))
        # server_ids is a "|"-separated id list; empty or non-numeric parts are ignored.
        self.flow_all_nodes = {int(i) for i in self.flow.server_ids.split("|") if i and i.isdigit()}

    def get_status(self, init: bool = False):
        """Return a snapshot of the flow: its dict form, step data and ``now_idx``.

        ``init`` is accepted but not used here (presumably part of the
        StatusServer callback signature — confirm).
        """
        flow_data = self.flow.to_dict()
        flow_data["steps"] = [x.to_show_data() for x in self.steps]
        flow_data["now_idx"] = self.now_idx
        return flow_data

    def start_status_server(self):
        """Start the status server on a daemon thread and register its cleanup."""
        t = threading.Thread(target=self.status_server.start_server, args=(), daemon=True)
        t.start()
        register_cleanup(self.status_server)

    def update_status(self, update_data: Dict):
        """Forward ``update_data`` to the status server."""
        self.status_server.update_status(update_data)

    def _run(self) -> bool:
        """Execute every step that is not already completed, in order.

        Returns:
            True if no executed step reported an error, otherwise False.
            Unless ``run_when_error`` is set, returns False as soon as one
            step reports an error.
        """
        def call_log(log_data):
            self.update_status(log_data)

        all_status = True  # stays True only if every executed step succeeded
        error_nodes = set()  # nodes that reported errors in earlier steps
        for step in self.steps:
            if not (self.flow_all_nodes - error_nodes):
                # Every node has already failed: nothing left to run on.
                # Note that now_idx is not advanced for steps skipped here.
                continue

            if isinstance(step, CommandTask):
                runner = self.run_cmd_task
            elif isinstance(step, TransferTask):
                runner = self.run_transfer_task
            else:
                runner = None

            # status 2 marks an already completed step; do not run it again.
            if runner is not None and step.status != 2:
                has_err, task_error_nodes = runner(step, call_log, exclude_nodes=list(error_nodes))
                all_status = all_status and not has_err
                if has_err and not self.run_when_error:
                    return False
                if self.exclude_when_error and task_error_nodes:
                    error_nodes.update(task_error_nodes)
            self.now_idx += 1
        return all_status

    def start(self):
        """Run the flow, persisting its status and serving progress meanwhile.

        The flow is stored as "running" before execution and as "complete" or
        "error" afterwards. If execution raises, the flow is stored as
        "error" and the exception propagates. The status server is stopped
        in every case; previously an exception left the flow "running" and
        the server alive.
        """
        self.start_status_server()

        all_status = False
        try:
            self.flow.status = "running"
            self._fdb.Flow.update(self.flow)
            all_status = self._run()
        finally:
            try:
                self.flow.status = "complete" if all_status else "error"
                self._fdb.Flow.update(self.flow)
            finally:
                self.status_server.stop()

    @staticmethod
    def run_cmd_task(task: CommandTask, call_log: Callable[[Any], None], exclude_nodes: Optional[List[int]] = None) -> Tuple[bool, List[int]]:
        """Run one command step via ``CMDTask``.

        Returns:
            ``(has_error, error_nodes)`` read from the runner's ``status_dict``.
        """
        runner = CMDTask(task, 0, call_log, exclude_nodes=exclude_nodes)
        runner.start()
        return runner.status_dict["error"] > 0, runner.status_dict["error_nodes"]

    @staticmethod
    def run_transfer_task(task: TransferTask, call_log: Callable[[Any], None], exclude_nodes: Optional[List[int]] = None) -> Tuple[bool, List[int]]:
        """Run one file-transfer step.

        A non-zero ``src_node_task_id`` selects ``NodeFiletransferTask``;
        otherwise ``FiletransferTask`` is used (presumably node-to-node vs.
        local-source transfers — confirm).

        Returns:
            ``(has_error, error_nodes)`` read from the runner's ``status_dict``.
        """
        if task.src_node_task_id != 0:
            runner = NodeFiletransferTask(task, call_log, exclude_nodes=exclude_nodes, the_log_id=None)
        else:
            runner = FiletransferTask(task, call_log, exclude_nodes=exclude_nodes)
        runner.start()
        return runner.status_dict["error"] > 0, runner.status_dict["error_nodes"]


def flow_running_log(task_id: int, call_log: Callable[[Union[str, dict]], None], timeout: float = 3.0) -> str:
    """Follow a running flow task's status stream.

    Waits up to ``timeout`` seconds for the task's status socket file to
    appear. It then connects a ``StatusClient`` with ``call_log`` as its
    callback and blocks in ``wait_receive()``.

    Args:
        task_id: Flow id; selects ``_SOCKET_FILE_DIR + "/flow_task_<task_id>"``.
        call_log: Callback handed to ``StatusClient``.
        timeout: Seconds to wait for the socket file to appear.

    Returns:
        "" once ``wait_receive()`` returns. Returns "任务启动超时" (task start
        timed out) if the socket file never appeared.
    """
    socket_file = _SOCKET_FILE_DIR + "/flow_task_" + str(task_id)
    poll_interval = 0.05
    # Measure real elapsed time. Subtracting the nominal poll interval ignored
    # time spent in exists() and sleep overshoot, so the wait could run past
    # ``timeout``.
    deadline = time.monotonic() + timeout
    while not os.path.exists(socket_file):
        if time.monotonic() >= deadline:
            return "任务启动超时"
        time.sleep(poll_interval)

    s_client = StatusClient(socket_file, callback=call_log)
    s_client.connect()
    s_client.wait_receive()
    return ""

def flow_useful_version(ver: str) -> bool:
    """Return True if the dotted version string ``ver`` is 11.4 or newer.

    Every dot-separated component must parse as an int. A bare major
    component is accepted only when it is above 11 (e.g. "12"); "11" alone is
    rejected. Unparsable or non-string input yields False.
    """
    # TODO: temporary handling; confirm the latest version-check logic before release.
    try:
        parts = tuple(int(p) for p in ver.split("."))
    except (AttributeError, TypeError, ValueError):
        # Only parsing failures mean "not supported"; a bare `except:` would
        # also have swallowed KeyboardInterrupt/SystemExit.
        return False
    # Lexicographic tuple comparison: (12,) >= (11, 4) but (11,) < (11, 4).
    return parts[:2] >= (11, 4)