File size: 7,051 Bytes
3a5cf48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import json
import threading
import queue
import time
import traceback
from typing import List, Dict, Callable, Any, Union

from mod.base.ssh_executor import SSHExecutor
from mod.project.node.dbutil import ServerNodeDB, CommandTask, CommandLog, TaskFlowsDB


class CMDTask(object):
    """Run one command task's script on its target nodes over SSH.

    Each target node is represented by a ``CommandLog`` entry of the task.
    ``start()`` launches one worker thread per entry plus one aggregator
    thread (``end_func``) that persists status changes and reports progress
    through ``call_update``.

    Status codes as used by this class:
        * log status 1 - set right before an entry is scheduled
        * log status 2 - script returned exit code 0
        * log status 3 - could not be run / raised an exception
        * log status 4 - script returned a non-zero exit code
        * task status 1 while running, then 2 (no errors) or 3 (any error)
    """

    def __init__(self, task: Union[int, CommandTask], log_id: int, call_update: Callable[[Any], None], exclude_nodes: List[int] = None):
        """Load the task and the log entries that should be executed.

        :param task: task id, or an already loaded ``CommandTask``.
        :param log_id: ``0`` selects every log entry of the task, any other
            value selects only the entry with that id.
        :param call_update: callback receiving ``status_dict`` on every
            progress report.
        :param exclude_nodes: server ids whose log entries are skipped.
        :raises ValueError: ``task`` is neither an int nor a ``CommandTask``.
        :raises RuntimeError: the task or its log entries cannot be found.
        """
        self._edb = TaskFlowsDB()
        if isinstance(task, int):
            self.task = self._edb.CommandTask.find("id = ?", (task,))
        elif isinstance(task, CommandTask):
            self.task = task
        else:
            raise ValueError("task 参数错误")
        if not self.task:
            raise RuntimeError("指定任务不存在")
        if log_id == 0:
            elogs = self._edb.CommandLog.query("command_task_id = ? ", (self.task.id,))
        else:
            elogs = [self._edb.CommandLog.find("command_task_id = ? AND id = ?", (self.task.id, log_id))]
        # find() is treated above as returning a falsy value when nothing
        # matches; drop such results so that a missing log entry is reported
        # here instead of surfacing later as an AttributeError on None.
        elogs = [x for x in (elogs or []) if x]
        if not elogs:
            raise RuntimeError("任务无执行条目")

        self._exclude_nodes = exclude_nodes or []
        self.task.elogs = [x for x in elogs if x.server_id not in self._exclude_nodes]

        self.task.status = 1
        self._edb.CommandTask.update(self.task)
        # Workers push their log entry here whenever its status changes.
        self.end_queue = queue.Queue()
        # Set by start() once every worker has finished; lets end_func stop.
        self.end_status = False
        self.status: List[Dict] = []
        self.call_update = call_update
        # Progress snapshot handed to call_update and returned to callers.
        self.status_dict: Dict[str, Union[List[Any], int]] = {
            "task_id": self.task.id,
            "task_type": "command",
            "flow_idx": self.task.step_index - 1,
            "count": len(self.task.elogs),
            "complete": 0,
            "error": 0,
            "exclude_nodes": self._exclude_nodes,
            "error_nodes": [],
            "data": [],
        }

    def end_func(self):
        """Aggregate status changes from ``end_queue`` until the run is over.

        Updates the ``complete`` / ``error`` / ``error_nodes`` counters of
        ``status_dict``, and at most every 0.5 seconds writes the changed log
        statuses to the database and invokes ``call_update``. Runs in its own
        thread; returns once ``end_status`` is set and the queue is drained.
        """
        edb = TaskFlowsDB()
        tmp_dict: Dict[int, CommandLog] = {}
        last_time = time.time()
        update_fields = ("status",)
        complete_set, error_set = set(), set()

        def flush():
            # Persist the pending entries and publish them to the callback.
            edb.CommandLog.bath_update(tmp_dict.values(), update_fields=update_fields)
            self.status_dict["data"] = [entry.to_show_data() for entry in tmp_dict.values()]
            self.call_update(self.status_dict)
            tmp_dict.clear()

        try:
            while True:
                # Sample the flag BEFORE polling: if it was already set and
                # the poll still times out, no producer can add items anymore,
                # so the queue is guaranteed to be drained.
                finished = self.end_status
                try:
                    elog = self.end_queue.get(timeout=0.1)
                except queue.Empty:
                    elog = None
                except Exception as e:
                    print(e)
                    break

                if elog is not None:
                    if elog.status in (3, 4):
                        # The same entry object can be dequeued twice while
                        # already failed; record its node only once.
                        if elog.id not in error_set:
                            error_set.add(elog.id)
                            self.status_dict["error_nodes"].append(int(elog.server_id))
                        self.status_dict["error"] = len(error_set)
                    elif elog.status == 2:
                        complete_set.add(elog.id)
                        self.status_dict["complete"] = len(complete_set)
                    tmp_dict[elog.id] = elog

                # Throttle reporting to one flush per 0.5s, and also flush
                # while the queue is idle so pending rows are not held back
                # until the next status change arrives.
                if tmp_dict and time.time() - last_time > 0.5:
                    flush()
                    last_time = time.time()

                if elog is None and finished:
                    break

            if tmp_dict:
                flush()
        finally:
            edb.close()

        return

    def start(self):
        """Execute the task on every selected node and wait for completion.

        Entries already in status 2 are only re-reported. Entries whose node
        or SSH configuration is missing/unreadable are marked failed (3).
        All other entries get a worker thread running ``run_one``. When all
        workers are done the task status is stored as 3 if any entry failed,
        otherwise 2, and the task database handle is closed.
        """
        thread_list = []
        s_db = ServerNodeDB()
        end_th = threading.Thread(target=self.end_func)
        end_th.start()

        try:
            try:
                for (idx, log) in enumerate(self.task.elogs):
                    log.log_idx = idx
                    if log.status == 2:  # skip entries that already completed
                        self.end_queue.put(log)
                        continue

                    log.status = 1
                    ssh_conf = None
                    node = s_db.get_node_by_id(log.server_id)
                    if not node:
                        log.status = 3
                        log.write_log("节点数据丢失,无法执行\n")

                    else:
                        try:
                            ssh_conf = json.loads(node["ssh_conf"])
                        except (KeyError, TypeError, ValueError):
                            # Missing or malformed configuration is handled
                            # like an empty one instead of aborting the run.
                            ssh_conf = None
                        if not ssh_conf:
                            log.status = 3
                            log.write_log("节点ssh配置数据丢失,无法执行\n")

                    self.end_queue.put(log)

                    if not ssh_conf:
                        continue

                    thread = threading.Thread(target=self.run_one, args=(ssh_conf, log))
                    thread.start()
                    thread_list.append(thread)
            finally:
                # Always stop the aggregator, even if scheduling failed half
                # way; otherwise its non-daemon thread would poll forever.
                for i in thread_list:
                    i.join()
                self.end_status = True
                end_th.join()

            if self.status_dict["error"] > 0:
                self.task.status = 3
            else:
                self.task.status = 2
            self._edb.CommandTask.update(self.task)
        finally:
            self._edb.close()

    def run_one(self, ssh_conf: dict, elog: CommandLog):
        """Run the task script on one node and record the outcome in ``elog``.

        :param ssh_conf: connection settings; the keys ``host``, ``port``,
            ``username``, ``password``, ``pkey`` and ``pkey_passwd`` are read.
        :param elog: log entry of the node; receives the script output and
            the final status (2 success, 4 non-zero exit code, 3 exception).

        The entry is pushed to ``end_queue`` once the outcome is known.
        Exceptions are logged and stored in ``elog``; none are propagated.
        """
        ssh = None
        try:
            # Built inside the try block so that an incomplete ssh_conf marks
            # the entry as failed instead of silently killing the thread.
            ssh = SSHExecutor(
                host=ssh_conf["host"],
                port=ssh_conf["port"],
                username=ssh_conf["username"],
                password=ssh_conf["password"],
                key_data=ssh_conf["pkey"],
                passphrase=ssh_conf["pkey_passwd"])
            elog.write_log("开始执行任务\n开始建立ssh连接...\n")
            ssh.open()

            def on_stdout(data):
                if isinstance(data, bytes):
                    # Output may be non-UTF-8 or split inside a multi-byte
                    # character; never let decoding abort the stream.
                    data = data.decode("utf-8", errors="replace")
                elog.write_log(data)

            elog.write_log("开始执行脚本...\n\n")
            t = time.time()
            res_code = ssh.execute_script_streaming(
                script_content=self.task.script_content,
                script_type=self.task.script_type,
                timeout=60 * 60,
                on_stdout=on_stdout,
                on_stderr=on_stdout  # stderr goes into the same log
            )
            take_time = round((time.time() - t) * 1000, 2)
            elog.write_log("\n\n执行结束,耗时[{}ms]\n".format(take_time))
            if res_code == 0:
                elog.status = 2
                elog.write_log("任务完成\n", is_end_log=True)
            else:
                elog.status = 4
                elog.write_log("任务异常,返回状态码为:{}\n".format(res_code), is_end_log=True)
            self.end_queue.put(elog)
        except Exception as e:
            traceback.print_exc()
            elog.status = 3
            elog.write_log("\n任务失败,错误:" + str(e), is_end_log=True)
            self.end_queue.put(elog)
            return
        finally:
            # NOTE(review): SSHExecutor is defined outside this file; it is
            # assumed to offer close() as the counterpart of open() - confirm.
            # Looked up defensively so a missing method cannot break the run.
            closer = getattr(ssh, "close", None)
            if callable(closer):
                try:
                    closer()
                except Exception:
                    traceback.print_exc()


def command_task_run_sync(task_id: int, log_id: int) -> Union[str, Dict[str, Any]]:
    """Synchronously retry one failed sub-task (log entry) of a command task.

    :param task_id: id of the command task.
    :param log_id: id of the log entry to run again; it must belong to
        ``task_id`` and currently have status 3 or 4.
    :return: an error message (``str``) when a precondition is not met,
        otherwise the ``status_dict`` of the finished ``CMDTask`` run.
    """
    fdb = TaskFlowsDB()
    try:
        task = fdb.CommandTask.get_byid(task_id)
        if not task:
            return "任务不存在"
        log = fdb.CommandLog.get_byid(log_id)
        if not log:
            return "子任务不存在"
        if log.status not in (3, 4):
            return "子任务状态不为失败,无法重试"
        if log.command_task_id != task_id:
            return "子任务不属于该任务,无法重试"
        cmd_task = CMDTask(task, log_id=log_id, call_update=print)
        cmd_task.start()
        return cmd_task.status_dict
    finally:
        # Release the lookup handle on every path, including early returns
        # and exceptions raised while constructing or running the task.
        fdb.close()