perorina commited on
Commit
579bc1b
·
1 Parent(s): 36f7261

Create mbw_util/preset_weights.py

Browse files
Files changed (1) hide show
  1. scripts/mbw_util/preset_weights.py +55 -0
scripts/mbw_util/preset_weights.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ #
3
+ #
4
+ import os
5
+ from csv import DictReader
6
+
7
+ from modules import scripts
8
+
9
+
10
+ CSV_FILE_PATH = "csv/preset.tsv"
11
+ MYPRESET_PATH = "csv/preset_own.tsv"
12
+ HEADER = ["preset_name", "preset_weights"]
13
+ path_root = scripts.basedir()
14
+
15
+
16
+ class PresetWeights():
17
+ def __init__(self):
18
+ self.presets = {}
19
+
20
+ if os.path.exists(os.path.join(path_root, MYPRESET_PATH)):
21
+ with open(os.path.join(path_root, MYPRESET_PATH), "r") as f:
22
+ reader = DictReader(f, delimiter="\t")
23
+ lines_dict = [row for row in reader]
24
+ for line_dict in lines_dict:
25
+ _w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
26
+ self.presets.update({line_dict["preset_name"]: _w})
27
+
28
+ with open(os.path.join(path_root, CSV_FILE_PATH), "r") as f:
29
+ reader = DictReader(f, delimiter="\t")
30
+ lines_dict = [row for row in reader]
31
+ for line_dict in lines_dict:
32
+ _w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
33
+ self.presets.update({line_dict["preset_name"]: _w})
34
+
35
+ def get_preset_name_list(self):
36
+ return [k for k in self.presets.keys()]
37
+
38
+ def find_weight_by_name(self, preset_name=""):
39
+ if preset_name and preset_name != "" and preset_name in self.presets.keys():
40
+ return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
41
+ else:
42
+ return ""
43
+
44
+ def find_names_by_weight(self, weights=""):
45
+ if weights and weights != "":
46
+ if weights in self.presets.values():
47
+ return [k for k, v in self.presets.items() if v == weights]
48
+ else:
49
+ _val = ",".join([f"{x.strip()}" for x in weights.split(",")])
50
+ if _val in self.presets.values():
51
+ return [k for k, v in self.presets.items() if v == _val]
52
+ else:
53
+ return []
54
+ else:
55
+ return []