| import os, re, numpy as np
|
| import matplotlib.pyplot as plt
|
| import matplotlib.pyplot as plt_xsec
|
| from datetime import datetime
|
|
|
# Prediction log file read line by line by output_logs().
input_log_file = './ewnet_logs_TRANS3_20240708.txt'

# True: predicted polygons are filled onto one cross-section figure per station.
# False: a stacked token-ratio bar chart is saved per selected log line instead.
flag_all_xsections = True

# Station of the cross-section figure currently being drawn (maintained by output_logs()).
prev_station = ''

# Run timestamp used in the box_list_/polygon_ figure file names.
now = datetime.now()
now_str = now.strftime('%Y%m%d_%H%M')

# Segment labels and their RGB fill colours: label_list[i] is drawn with color_list[i].
label_list = [
    'pave_layer1', 'pave_layer2', 'pave_layer3', 'pave_layer4',
    'cut_ea', 'cut_rr', 'cut_br', 'cut_ditch',
    'fill_subbed', 'fill_subbody', 'curb', 'above', 'below',
    'pave_int', 'pave_surface', 'pave_subgrade', 'ground', 'pave_bottom',
    'rr', 'br', 'slope', 'struct', 'steps',
]

color_list = [
    [0.8, 0.8, 0.8], [0.6, 0.6, 0.6], [0.4, 0.4, 0.4], [0.2, 0.2, 0.2],
    [0.8, 0.4, 0.2], [0.8, 0.6, 0.2], [0.8, 0.8, 0.2], [0.6, 0.8, 0.2],
    [0.3, 0.8, 0.3], [0.3, 0.6, 0.3], [0.3, 0.4, 0.3], [0.0, 0.8, 0.0], [0.6, 0.0, 0.0],
    [0.8, 0.0, 0.0], [1.0, 0.0, 0.0], [0.2, 0.2, 0.6], [0.0, 1.0, 0.0], [0.2, 0.2, 1.0],
    [0.4, 0.2, 1.0], [0.6, 0.2, 1.0], [0.2, 0.8, 0.6], [0.8, 0.2, 1.0], [1.0, 0.2, 1.0],
]

# Colours are looked up by label position, so the two lists must stay aligned.
if len(label_list) != len(color_list):
    raise ValueError('label_list and color_list must have the same length')

# All figures and the summary CSV are written here; exist_ok avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs('./graph', exist_ok=True)
|
|
|
def draw_colorbox_list():
    """Save a swatch chart of every label colour to ``./graph/box_colors.png``.

    Each entry of ``label_list`` gets an equal-width horizontal segment
    (together they span x in [0, 1]) on six dummy rows, filled with the
    matching ``color_list`` entry; the legend maps colours back to labels.
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.invert_yaxis()
    # Swatches occupy x in [0, 1]; the extra 0.5 is free space (presumably for the legend).
    ax.set_xlim(0, 1.5)

    rows = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']
    # The segment width is the same for every label, so compute it once.
    width = 1.0 / len(label_list)
    widths = [width] * len(rows)

    for i, (colname, color) in enumerate(zip(label_list, color_list)):
        ax.barh(rows, widths, left=width * i, height=0.5, label=colname, color=color)

    ax.legend()
    fig.savefig('./graph/box_colors.png')
    plt.close(fig)
|
|
|
def _parse_geom_tokens(text):
    """Parse the ``Predicted:``, ``Geom:`` and ``Polyline:`` sections of one log line.

    Returns ``None`` when the line has no ``Geom:`` or no ``Predicted: ``
    section (nothing can be drawn). Otherwise returns
    ``(prediction, tokens, polyline)``:

    * ``prediction`` -- first ``', '``-separated entry after ``Predicted: ``.
    * ``tokens`` -- ``'name(ratio'`` strings; ``tokens[0]`` is the prediction
      with a fixed ratio of 0.3, followed by the entries of the ``Geom:`` list.
    * ``polyline`` -- the evaluated ``Polyline:`` value, or ``[]`` if absent.
    """
    geom_index = text.find('Geom:')
    label_index = text.find('Predicted: ')
    if geom_index < 0 or label_index < 0:
        return None

    # len('Predicted: ') == 11; the predicted label is the first ', '-separated entry.
    prediction = text[label_index + 11:geom_index].split(', ')[0]

    polyline = []
    polyline_index = text.find('Polyline:')
    if polyline_index > 0:
        # len('Geom: ') == 6; the two characters before 'Polyline:' are the ', ' separator.
        geom_text = text[geom_index + 6:polyline_index - 2]
        # NOTE(review): eval() executes arbitrary code read from the log file. Kept because
        # the logged polyline format is not visible here; prefer ast.literal_eval if it is
        # always a plain Python literal.  len('Polyline: ') == 10.
        polyline = eval(text[polyline_index + 10:])
    else:
        geom_text = text[geom_index + 6:]

    # Drop list brackets, quotes and closing parens, leaving 'name(ratio' items.
    geom_text = geom_text.translate(str.maketrans('', '', "[])'"))
    tokens = geom_text.split(',')
    if len(tokens) <= 1:
        tokens = geom_text.split(' ')
    # The prediction itself is drawn first with a fixed 0.3-wide segment.
    tokens.insert(0, prediction + '(0.3')
    if not tokens[-1]:
        tokens.pop()  # empty entry left by a trailing separator
    return prediction, tokens, polyline


def _draw_token_bar(index, tag, text, tokens, token_list):
    """Save a one-row stacked bar chart of the token ratios to ``./graph``.

    The file name is ``box_list_<now_str>_<tag>_<index>_<T|F>.png`` where the
    tag has spaces replaced by ``_`` and colons removed, and the suffix is
    ``T`` when ``'True'`` occurs in *text* after position 0.
    """
    ratios = [float(token.split('(')[1]) for token in tokens]
    data = np.array([ratios])  # shape (1, n): a single bar row
    data_cum = data.cumsum(axis=1)
    token_colors = [color_list[label_list.index(label)] for label in token_list]
    # Every segment after the prediction is shifted right, leaving a gap after segment 0.
    gap = 0.02 if len(ratios) > 1 else 0.0

    fig, ax = plt.subplots(figsize=(15, 0.5))
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    # Include the gap so the last segment is not clipped at the right edge.
    ax.set_xlim(0, data.sum() + gap)

    for i, (colname, color) in enumerate(zip(token_list, token_colors)):
        widths = data[:, i]
        starts = data_cum[:, i] - widths
        if i > 0:
            starts = starts + gap
        rects = ax.barh([token_list[0]], widths, left=starts, height=0.5, label=colname, color=color)
        if i > 0:
            # Dark fills get white value labels, light fills black ones.
            text_color = 'white' if np.max(color) < 0.4 else 'black'
            ax.bar_label(rects, label_type='center', color=text_color)
    ax.legend(ncols=len(token_list), bbox_to_anchor=(0, 1), loc='lower right', fontsize='small')

    safe_tag = tag.replace(' ', '_').replace(':', '')
    suffix = 'T' if text.find('True') > 0 else 'F'
    fig.savefig(f'./graph/box_list_{now_str}_{safe_tag}_{index}_{suffix}.png')
    plt.close(fig)


def _draw_xsection_polygon(prediction, polyline):
    """Fill *polyline* on the current ``plt_xsec`` figure and return its area.

    Polygons whose prediction does not contain ``'pave'`` are also annotated
    with ``'<prediction>=<area>'``.
    """
    ring = list(polyline)
    if ring[0] != ring[-1]:
        ring.append(ring[0])  # close the ring for fill() and the shoelace sum
    x, y = zip(*ring)
    plt_xsec.fill(x, y, color=color_list[label_list.index(prediction)])

    # Shoelace formula over the closed ring.
    area = 0.5 * abs(sum(x[i] * y[i + 1] - x[i + 1] * y[i] for i in range(len(ring) - 1)))

    if 'pave' not in prediction:
        # Place the label at the mean of the distinct vertices; including the
        # closing duplicate would pull it toward the first vertex.
        xs = x[:-1] or x
        ys = y[:-1] or y
        plt_xsec.text(sum(xs) / len(xs), sum(ys) / len(ys), f'{prediction}={area:.2f}',
                      horizontalalignment='center', verticalalignment='center',
                      fontsize=5, color='black')
    return area


def output_graph_matrics(index, tag, text):
    """Render one prediction log line and return ``(prediction, area, token_list)``.

    When ``flag_all_xsections`` is false, a token-ratio bar chart is saved
    for the line. Otherwise the line's polyline (if any) is filled on the
    current cross-section figure.

    Args:
        index: Position of the line in the caller's list; used in the bar-chart file name.
        tag: Model tag; used in the bar-chart file name.
        text: One newline-stripped log line.

    Returns:
        ``None`` if *text* has no ``Geom:``/``Predicted:`` section. Otherwise a
        tuple of the predicted label, the polygon area (``0.0`` when no polygon
        was drawn), and the token label names with the prediction first.
    """
    parsed = _parse_geom_tokens(text)
    if parsed is None:
        return None
    prediction, tokens, polyline = parsed

    # Label names: text before '(' with all spaces removed.
    token_list = [token.split('(')[0].replace(' ', '') for token in tokens]

    # Previously `area` was only bound in the cross-section branch, so the
    # bar-chart path raised UnboundLocalError at the return.
    area = 0.0
    if not flag_all_xsections:
        _draw_token_bar(index, tag, text, tokens, token_list)
    elif polyline:
        area = _draw_xsection_polygon(prediction, polyline)

    return prediction, area, token_list
|
|
|
# Station strings compared verbatim against the text after 'Station: ' in each
# log line. When empty, output_logs() draws every station and main() writes the
# area summary instead of only the per-station figures.
output_stations = [
    '4+440.00000',
    '3+780.00000',
    '3+800.00000',
    '3+880.00000',
    '3+940.00000',
]
|
def _select_log_lines(tag, equal):
    """Return the newline-stripped lines of ``input_log_file`` selected for *tag*.

    With ``flag_all_xsections`` true, every line containing the model name
    (first word of *tag*) is considered. Otherwise, for each entry of
    ``label_list``, only the first line containing both *tag* and
    ``'Label: <label>'`` is considered. A considered line is kept when
    *equal* is ``'none'`` or occurs in the line after position 0.
    """
    tag_model = tag.split(' ')[0]
    selected = []

    def accept(line):
        return equal == 'none' or line.find(equal) > 0

    with open(input_log_file, 'r') as file:
        if flag_all_xsections:
            for line in file:
                if tag_model in line:
                    stripped = line.replace('\n', '')
                    if accept(stripped):
                        selected.append(stripped)
        else:
            for label in label_list:
                file.seek(0)  # rescan the whole file for every label
                for line in file:
                    if tag in line and ('Label: ' + label) in line:
                        stripped = line.replace('\n', '')
                        if accept(stripped):
                            selected.append(stripped)
                        break
    return selected


def _extract_station(text):
    """Return the text after ``'Station: '`` up to the next comma (sort key)."""
    sta_index = text.find('Station:') + 9
    end_index = text.find(',', sta_index)
    return text[sta_index:end_index] if end_index != -1 else text[sta_index:]


def _save_station_figure(tag, station, equal_check):
    """Save the current ``plt_xsec`` figure for *station* and close it."""
    plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{station}_{equal_check}.png', dpi=300)
    plt_xsec.close()


def output_logs(tag, equal='none'):
    """Draw every log line selected for *tag* and return one record per drawn line.

    Lines are chosen by ``_select_log_lines(tag, equal)``, sorted by their
    station string, and filtered by ``output_stations`` when that list is
    non-empty. A fresh ``plt_xsec`` figure is started for each new station and
    saved as ``./graph/polygon_<now_str>_<tag>_<station>_<T|F>.png``; the
    suffix comes from the last line drawn for that station. Each line is then
    rendered with ``output_graph_matrics``.

    Args:
        tag: Model tag searched for in the log lines.
        equal: Substring a line must contain (after position 0), or ``'none'``
            to keep every selected line.

    Returns:
        A list of dicts with keys ``index``, ``station``, ``label``, ``area``
        and ``tokens``. Lines that ``output_graph_matrics`` cannot parse are
        skipped.
    """
    global prev_station

    logs = []
    text_list = _select_log_lines(tag, equal)
    if not text_list:
        return logs

    text_list = sorted(text_list, key=_extract_station)

    # Start every call with no open station figure; a station left over from a
    # previous call would otherwise suppress creating a new figure and be saved
    # later under this call's tag.
    prev_station = ''
    station = ''
    station_check = 'F'
    for index, text in enumerate(text_list):
        sta_index = text.find('Station:')
        equal_index = text.find('Equal: ')
        line_check = 'T' if text.find('True') > 0 else 'F'

        if sta_index > 0 and equal_index > 0:
            # len('Station: ') == 9; the value ends at the ', ' preceding 'Equal: '.
            station = text[sta_index + 9:equal_index - 2]
            print(station)

        if output_stations and station not in output_stations:
            continue

        if prev_station != station:
            if prev_station:
                _save_station_figure(tag, prev_station, station_check)
            plt_xsec.figure()
            axes = plt_xsec.gca()
            axes.set_xlim([-60, 60])
            axes.axis('equal')
            axes.text(0, 0, f'{station}', fontsize=12, color='black')
            prev_station = station
        station_check = line_check

        result = output_graph_matrics(index, tag, text)
        if result is None:
            continue  # no Geom:/Predicted: section to draw
        label, area, tokens = result
        logs.append({
            'index': index,
            'station': station,
            'label': label,
            'area': area,
            'tokens': tokens,
        })

    # Save the last station's figure here rather than only when the final line
    # is drawn: if that line is filtered out the figure would never be saved.
    if prev_station:
        _save_station_figure(tag, prev_station, station_check)

    return logs
|
|
|
def main():
    """Draw the colour legend, then render and summarise the logs for each model tag.

    When ``output_stations`` is non-empty, only the per-station figures are
    produced for every tag. Otherwise, for each tag, the areas of the
    ``'Equal: True'`` and ``'Equal: False'`` lines are summed and appended,
    with their counts, to ``./graph/summary_log.csv``. In that mode, with
    ``flag_all_xsections`` set, the loop stops after the first tag written.
    """
    draw_colorbox_list()

    tags = ['MLP [128, 64, 32]', 'MLP [64, 128, 64]', 'MLP [64, 128, 64, 32]', 'LSTM [128]', 'LSTM [128, 64, 32]', 'LSTM [256, 128, 64]', 'transformer 32', 'transformer 64', 'transformer 128', 'BERT']

    # open() raises on failure rather than returning None, and `with` closes
    # the file even if rendering fails part-way through.
    with open('./graph/summary_log.csv', 'a') as summary_log_file:
        summary_log_file.write('model, ground true, length, ground false, length\n')

        for tag in tags:
            print(tag)
            if output_stations:
                output_logs(tag)
                continue

            logs1 = output_logs(tag, 'Equal: True')
            logs2 = output_logs(tag, 'Equal: False')
            if not logs1 or not logs2:
                continue

            area1 = sum(log['area'] for log in logs1)
            area2 = sum(log['area'] for log in logs2)
            summary_log_file.write(f'{tag}, {area1}, {len(logs1)}, {area2}, {len(logs2)}\n')

            if flag_all_xsections:
                break
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, not on import.
    main()
|
|
|