import json
import threading
import queue
import time
import traceback
from typing import List, Dict, Callable, Any, Optional, Union

from mod.base.ssh_executor import SSHExecutor
from mod.project.node.dbutil import ServerNodeDB, CommandTask, CommandLog, TaskFlowsDB


class CMDTask(object):
    """Run one command task over SSH on every node that has an execution log.

    One worker thread is started per node (``run_one``); a single collector
    thread (``end_func``) drains ``end_queue``, writes status changes through
    ``CommandLog.bath_update`` and reports progress through ``call_update``.

    Status values written by this class:
        * log status 1 - dispatched, 2 - script returned 0,
          3 - could not run / raised, 4 - script returned a non-zero code.
        * task status 1 - set in ``__init__``, 2 - finished without errors,
          3 - finished with at least one failed log.
    """

    # Minimum number of seconds between two progress flushes in ``end_func``.
    _FLUSH_INTERVAL = 0.5

    def __init__(self,
                 task: Union[int, CommandTask],
                 log_id: int,
                 call_update: Callable[[Any], None],
                 exclude_nodes: Optional[List[int]] = None):
        """Load the task and the execution logs that should be (re)run.

        Args:
            task: a task id or an already loaded ``CommandTask``.
            log_id: ``0`` selects every log of the task, any other value
                selects only that single log.
            call_update: callback invoked with ``status_dict`` on progress.
            exclude_nodes: server ids whose logs must be skipped.

        Raises:
            ValueError: ``task`` is neither an int nor a ``CommandTask``.
            RuntimeError: the task does not exist, or no log matches.
        """
        self._edb = TaskFlowsDB()
        if isinstance(task, int):
            self.task = self._edb.CommandTask.find("id = ?", (task,))
        elif isinstance(task, CommandTask):
            self.task = task
        else:
            raise ValueError("task 参数错误")
        if not self.task:
            raise RuntimeError("指定任务不存在")

        if log_id == 0:
            self.task.elogs = self._edb.CommandLog.query("command_task_id = ? ", (self.task.id,))
        else:
            single_log = self._edb.CommandLog.find(
                "command_task_id = ? AND id = ?", (self.task.id, log_id))
            # A missing log must give an empty list: ``[None]`` is truthy and
            # would slip past the emptiness check below, then crash on
            # ``None.server_id``.
            self.task.elogs = [single_log] if single_log else []

        if not self.task.elogs:
            raise RuntimeError("任务无执行条目")

        self._exclude_nodes = exclude_nodes or []
        self.task.elogs = [x for x in self.task.elogs if x.server_id not in self._exclude_nodes]
        self.task.status = 1
        self._edb.CommandTask.update(self.task)

        self.end_queue = queue.Queue()
        self.end_status = False
        self.status: List[Dict] = []
        self.call_update = call_update
        self.status_dict: Dict[str, Any] = {
            "task_id": self.task.id,
            "task_type": "command",
            "flow_idx": self.task.step_index - 1,
            "count": len(self.task.elogs),
            "complete": 0,
            "error": 0,
            "exclude_nodes": self._exclude_nodes,
            "error_nodes": [],
            "data": [],
        }

    def end_func(self):
        """Collector loop: drain ``end_queue`` and publish progress.

        Runs until ``end_status`` is set and the queue is empty. Status
        changes are batched and flushed at most once per ``_FLUSH_INTERVAL``
        seconds; whatever is still pending when the loop ends is flushed once
        more before returning.
        """
        edb = TaskFlowsDB()
        pending: Dict[int, CommandLog] = {}
        update_fields = ("status",)
        complete_set, error_set = set(), set()
        last_flush = time.time()

        def flush():
            # Persist the batched logs and hand a snapshot to the callback.
            edb.CommandLog.bath_update(pending.values(), update_fields=update_fields)
            self.status_dict["data"] = [entry.to_show_data() for entry in pending.values()]
            self.call_update(self.status_dict)
            pending.clear()

        try:
            while True:
                try:
                    elog: Optional[CommandLog] = self.end_queue.get(timeout=0.1)
                except queue.Empty:
                    if self.end_status:
                        break
                    elog = None  # idle tick: nothing new, but maybe time to flush
                except Exception as e:
                    print(e)
                    break

                if elog is not None:
                    if elog.status in (3, 4):
                        # The same log object can be queued twice (on dispatch
                        # and on completion), so record each failure once.
                        if elog.id not in error_set:
                            error_set.add(elog.id)
                            self.status_dict["error_nodes"].append(int(elog.server_id))
                        self.status_dict["error"] = len(error_set)
                    elif elog.status == 2:
                        complete_set.add(elog.id)
                        self.status_dict["complete"] = len(complete_set)
                    pending[elog.id] = elog

                if pending and time.time() - last_flush > self._FLUSH_INTERVAL:
                    flush()
                    # Restart the throttle window; otherwise every later item
                    # would trigger its own write and callback.
                    last_flush = time.time()

            if pending:
                flush()
        finally:
            edb.close()
        return

    def start(self):
        """Dispatch every selected log to its node and wait for completion.

        Logs already in status 2 are only reported, not re-run. Logs whose
        node or ssh configuration is missing are marked failed (status 3).
        When all workers and the collector have finished, the task status is
        set to 3 if any log failed, else 2, and the DB handle is closed.
        """
        thread_list = []
        s_db = ServerNodeDB()
        end_th = threading.Thread(target=self.end_func)
        end_th.start()

        try:
            for (idx, log) in enumerate(self.task.elogs):
                log.log_idx = idx
                if log.status == 2:  # skip entries that already completed
                    self.end_queue.put(log)
                    continue

                log.status = 1
                ssh_conf = None
                node = s_db.get_node_by_id(log.server_id)
                if not node:
                    log.status = 3
                    log.write_log("节点数据丢失,无法执行\n")
                else:
                    ssh_conf = json.loads(node["ssh_conf"])
                    if not ssh_conf:
                        log.status = 3
                        log.write_log("节点ssh配置数据丢失,无法执行\n")

                # Report the dispatch state (1) or the early failure (3).
                self.end_queue.put(log)
                if not ssh_conf:
                    continue

                thread = threading.Thread(target=self.run_one, args=(ssh_conf, log))
                thread.start()
                thread_list.append(thread)
        finally:
            # Always stop the collector, even if dispatching raised; it is a
            # non-daemon thread and would otherwise poll the queue forever.
            for worker in thread_list:
                worker.join()
            self.end_status = True
            end_th.join()

        if self.status_dict["error"] > 0:
            self.task.status = 3
        else:
            self.task.status = 2
        self._edb.CommandTask.update(self.task)
        self._edb.close()

    def run_one(self, ssh_conf: dict, elog: CommandLog):
        """Run the task script on one node and queue the resulting log.

        Args:
            ssh_conf: connection settings; the keys ``host``, ``port``,
                ``username``, ``password``, ``pkey`` and ``pkey_passwd`` are read.
            elog: the execution log of that node; its status becomes 2 on
                exit code 0, 4 on any other exit code and 3 on an exception.
        """
        # NOTE(review): the SSH session is never closed explicitly here -
        # confirm whether SSHExecutor requires an explicit close.
        try:
            # Built inside the try so that an incomplete ssh_conf marks the
            # node as failed instead of silently killing the worker thread.
            ssh = SSHExecutor(
                host=ssh_conf["host"],
                port=ssh_conf["port"],
                username=ssh_conf["username"],
                password=ssh_conf["password"],
                key_data=ssh_conf["pkey"],
                passphrase=ssh_conf["pkey_passwd"])
            elog.write_log("开始执行任务\n开始建立ssh连接...\n")
            ssh.open()

            def on_stdout(data):
                if isinstance(data, bytes):
                    # A stream chunk can end in the middle of a multi-byte
                    # character; never let that abort the whole run.
                    data = data.decode("utf-8", errors="replace")
                elog.write_log(data)

            elog.write_log("开始执行脚本...\n\n")
            t = time.time()
            res_code = ssh.execute_script_streaming(
                script_content=self.task.script_content,
                script_type=self.task.script_type,
                timeout=60 * 60,
                on_stdout=on_stdout,
                on_stderr=on_stdout
            )
            take_time = round((time.time() - t) * 1000, 2)
            elog.write_log("\n\n执行结束,耗时[{}ms]\n".format(take_time))
            if res_code == 0:
                elog.status = 2
                elog.write_log("任务完成\n", is_end_log=True)
            else:
                elog.status = 4
                elog.write_log("任务异常,返回状态码为:{}\n".format(res_code), is_end_log=True)
            self.end_queue.put(elog)
        except Exception as e:
            traceback.print_exc()
            elog.status = 3
            elog.write_log("\n任务失败,错误:" + str(e), is_end_log=True)
            self.end_queue.put(elog)
        return


# Synchronous retry of a failed sub-task of a command task.
def command_task_run_sync(task_id: int, log_id: int) -> Union[str, Dict[str, Any]]:
    """Re-run one failed execution log of a command task and wait for it.

    Args:
        task_id: id of the command task.
        log_id: id of the execution log to retry; it must belong to
            ``task_id`` and be in status 3 or 4.

    Returns:
        An error message string when the retry is refused, otherwise the
        ``status_dict`` of the finished run.
    """
    fdb = TaskFlowsDB()
    try:
        task = fdb.CommandTask.get_byid(task_id)
        if not task:
            return "任务不存在"

        log = fdb.CommandLog.get_byid(log_id)
        if not log:
            return "子任务不存在"

        if log.status not in (3, 4):
            return "子任务状态不为失败,无法重试"

        if log.command_task_id != task_id:
            return "子任务不属于该任务,无法重试"

        cmd_task = CMDTask(task, log_id=log_id, call_update=print)
        cmd_task.start()
        return cmd_task.status_dict
    finally:
        fdb.close()