| import os |
| import sys |
| import traceback |
| import queue |
| import threading |
| import time |
| import numpy as np |
| import itertools |
| from pathlib import Path |
| from core import pathex |
| from core import imagelib |
| import cv2 |
| import models |
| from core.interact import interact as io |
|
|
| def trainerThread (s2c, c2s, e, |
| model_class_name = None, |
| saved_models_path = None, |
| training_data_src_path = None, |
| training_data_dst_path = None, |
| pretraining_data_path = None, |
| pretrained_model_path = None, |
| no_preview=False, |
| force_model_name=None, |
| force_gpu_idxs=None, |
| cpu_only=None, |
| silent_start=False, |
| execute_programs = None, |
| debug=False, |
| **kwargs): |
| while True: |
| try: |
| start_time = time.time() |
|
|
| save_interval_min = 25 |
|
|
| if not training_data_src_path.exists(): |
| training_data_src_path.mkdir(exist_ok=True, parents=True) |
|
|
| if not training_data_dst_path.exists(): |
| training_data_dst_path.mkdir(exist_ok=True, parents=True) |
|
|
| if not saved_models_path.exists(): |
| saved_models_path.mkdir(exist_ok=True, parents=True) |
| |
| model = models.import_model(model_class_name)( |
| is_training=True, |
| saved_models_path=saved_models_path, |
| training_data_src_path=training_data_src_path, |
| training_data_dst_path=training_data_dst_path, |
| pretraining_data_path=pretraining_data_path, |
| pretrained_model_path=pretrained_model_path, |
| no_preview=no_preview, |
| force_model_name=force_model_name, |
| force_gpu_idxs=force_gpu_idxs, |
| cpu_only=cpu_only, |
| silent_start=silent_start, |
| debug=debug) |
|
|
| is_reached_goal = model.is_reached_iter_goal() |
|
|
| shared_state = { 'after_save' : False } |
| loss_string = "" |
| save_iter = model.get_iter() |
| def model_save(): |
| if not debug and not is_reached_goal: |
| io.log_info ("Saving....", end='\r') |
| model.save() |
| shared_state['after_save'] = True |
| |
| def model_backup(): |
| if not debug and not is_reached_goal: |
| model.create_backup() |
|
|
| def send_preview(): |
| if not debug: |
| previews = model.get_previews() |
| c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) |
| else: |
| previews = [( 'debug, press update for new', model.debug_one_iter())] |
| c2s.put ( {'op':'show', 'previews': previews} ) |
| e.set() |
|
|
| if model.get_target_iter() != 0: |
| if is_reached_goal: |
| io.log_info('Model already trained to target iteration. You can use preview.') |
| else: |
| io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) ) |
| else: |
| io.log_info('Starting. Press "Enter" to stop training and save model.') |
|
|
| last_save_time = time.time() |
|
|
| execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ] |
|
|
| for i in itertools.count(0,1): |
| if not debug: |
| cur_time = time.time() |
|
|
| for x in execute_programs: |
| prog_time, prog, last_time = x |
| exec_prog = False |
| if prog_time > 0 and (cur_time - start_time) >= prog_time: |
| x[0] = 0 |
| exec_prog = True |
| elif prog_time < 0 and (cur_time - last_time) >= -prog_time: |
| x[2] = cur_time |
| exec_prog = True |
|
|
| if exec_prog: |
| try: |
| exec(prog) |
| except Exception as e: |
| print("Unable to execute program: %s" % (prog) ) |
|
|
| if not is_reached_goal: |
|
|
| if model.get_iter() == 0: |
| io.log_info("") |
| io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.") |
| io.log_info("") |
| |
| if sys.platform[0:3] == 'win': |
| io.log_info("!!!") |
| io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.") |
| io.log_info("https://i.imgur.com/B7cmDCB.jpg") |
| io.log_info("!!!") |
|
|
| iter, iter_time = model.train_one_iter() |
|
|
| loss_history = model.get_loss_history() |
| time_str = time.strftime("[%H:%M:%S]") |
| if iter_time >= 10: |
| loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) ) |
| else: |
| loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) ) |
|
|
| if shared_state['after_save']: |
| shared_state['after_save'] = False |
| |
| mean_loss = np.mean ( loss_history[save_iter:iter], axis=0) |
|
|
| for loss_value in mean_loss: |
| loss_string += "[%.4f]" % (loss_value) |
|
|
| io.log_info (loss_string) |
|
|
| save_iter = iter |
| else: |
| for loss_value in loss_history[-1]: |
| loss_string += "[%.4f]" % (loss_value) |
|
|
| if io.is_colab(): |
| io.log_info ('\r' + loss_string, end='') |
| else: |
| io.log_info (loss_string, end='\r') |
|
|
| if model.get_iter() == 1: |
| model_save() |
|
|
| if model.get_target_iter() != 0 and model.is_reached_iter_goal(): |
| io.log_info ('Reached target iteration.') |
| model_save() |
| is_reached_goal = True |
| io.log_info ('You can use preview now.') |
| |
| need_save = False |
| while time.time() - last_save_time >= save_interval_min*60: |
| last_save_time += save_interval_min*60 |
| need_save = True |
| |
| if not is_reached_goal and need_save: |
| model_save() |
| send_preview() |
|
|
| if i==0: |
| if is_reached_goal: |
| model.pass_one_iter() |
| send_preview() |
|
|
| if debug: |
| time.sleep(0.005) |
|
|
| while not s2c.empty(): |
| input = s2c.get() |
| op = input['op'] |
| if op == 'save': |
| model_save() |
| elif op == 'backup': |
| model_backup() |
| elif op == 'preview': |
| if is_reached_goal: |
| model.pass_one_iter() |
| send_preview() |
| elif op == 'close': |
| model_save() |
| i = -1 |
| break |
|
|
| if i == -1: |
| break |
|
|
|
|
|
|
| model.finalize() |
|
|
| except Exception as e: |
| print ('Error: %s' % (str(e))) |
| traceback.print_exc() |
| break |
| c2s.put ( {'op':'close'} ) |
|
|
|
|
|
|
| def main(**kwargs): |
| io.log_info ("Running trainer.\r\n") |
|
|
| no_preview = kwargs.get('no_preview', False) |
|
|
| s2c = queue.Queue() |
| c2s = queue.Queue() |
|
|
| e = threading.Event() |
| thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs ) |
| thread.start() |
|
|
| e.wait() |
|
|
| if no_preview: |
| while True: |
| if not c2s.empty(): |
| input = c2s.get() |
| op = input.get('op','') |
| if op == 'close': |
| break |
| try: |
| io.process_messages(0.1) |
| except KeyboardInterrupt: |
| s2c.put ( {'op': 'close'} ) |
| else: |
| wnd_name = "Training preview" |
| io.named_window(wnd_name) |
| io.capture_keys(wnd_name) |
|
|
| previews = None |
| loss_history = None |
| selected_preview = 0 |
| update_preview = False |
| is_showing = False |
| is_waiting_preview = False |
| show_last_history_iters_count = 0 |
| iter = 0 |
| while True: |
| if not c2s.empty(): |
| input = c2s.get() |
| op = input['op'] |
| if op == 'show': |
| is_waiting_preview = False |
| loss_history = input['loss_history'] if 'loss_history' in input.keys() else None |
| previews = input['previews'] if 'previews' in input.keys() else None |
| iter = input['iter'] if 'iter' in input.keys() else 0 |
| if previews is not None: |
| max_w = 0 |
| max_h = 0 |
| for (preview_name, preview_rgb) in previews: |
| (h, w, c) = preview_rgb.shape |
| max_h = max (max_h, h) |
| max_w = max (max_w, w) |
|
|
| max_size = 800 |
| if max_h > max_size: |
| max_w = int( max_w / (max_h / max_size) ) |
| max_h = max_size |
|
|
| |
| for preview in previews[:]: |
| (preview_name, preview_rgb) = preview |
| (h, w, c) = preview_rgb.shape |
| if h != max_h or w != max_w: |
| previews.remove(preview) |
| previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) ) |
| selected_preview = selected_preview % len(previews) |
| update_preview = True |
| elif op == 'close': |
| break |
|
|
| if update_preview: |
| update_preview = False |
|
|
| selected_preview_name = previews[selected_preview][0] |
| selected_preview_rgb = previews[selected_preview][1] |
| (h,w,c) = selected_preview_rgb.shape |
|
|
| |
| head_lines = [ |
| '[s]:save [b]:backup [enter]:exit', |
| '[p]:update [space]:next preview [l]:change history range', |
| 'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) ) |
| ] |
| head_line_height = 15 |
| head_height = len(head_lines) * head_line_height |
| head = np.ones ( (head_height,w,c) ) * 0.1 |
|
|
| for i in range(0, len(head_lines)): |
| t = i*head_line_height |
| b = (i+1)*head_line_height |
| head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c ) |
|
|
| final = head |
|
|
| if loss_history is not None: |
| if show_last_history_iters_count == 0: |
| loss_history_to_show = loss_history |
| else: |
| loss_history_to_show = loss_history[-show_last_history_iters_count:] |
|
|
| lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c) |
| final = np.concatenate ( [final, lh_img], axis=0 ) |
|
|
| final = np.concatenate ( [final, selected_preview_rgb], axis=0 ) |
| final = np.clip(final, 0, 1) |
|
|
| io.show_image( wnd_name, (final*255).astype(np.uint8) ) |
| is_showing = True |
|
|
| key_events = io.get_key_events(wnd_name) |
| key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) |
|
|
| if key == ord('\n') or key == ord('\r'): |
| s2c.put ( {'op': 'close'} ) |
| elif key == ord('s'): |
| s2c.put ( {'op': 'save'} ) |
| elif key == ord('b'): |
| s2c.put ( {'op': 'backup'} ) |
| elif key == ord('p'): |
| if not is_waiting_preview: |
| is_waiting_preview = True |
| s2c.put ( {'op': 'preview'} ) |
| elif key == ord('l'): |
| if show_last_history_iters_count == 0: |
| show_last_history_iters_count = 5000 |
| elif show_last_history_iters_count == 5000: |
| show_last_history_iters_count = 10000 |
| elif show_last_history_iters_count == 10000: |
| show_last_history_iters_count = 50000 |
| elif show_last_history_iters_count == 50000: |
| show_last_history_iters_count = 100000 |
| elif show_last_history_iters_count == 100000: |
| show_last_history_iters_count = 0 |
| update_preview = True |
| elif key == ord(' '): |
| selected_preview = (selected_preview + 1) % len(previews) |
| update_preview = True |
|
|
| try: |
| io.process_messages(0.1) |
| except KeyboardInterrupt: |
| s2c.put ( {'op': 'close'} ) |
|
|
| io.destroy_all_windows() |