mally-2000 commited on
Commit
d6f6beb
·
verified ·
1 Parent(s): e8c53ea

Make full evaluation optional in infer script

Browse files

Add an --eval flag to infer.py and default to single-sample inference to keep Colab runs fast.

Files changed (2) hide show
  1. README.md +6 -0
  2. infer.py +23 -17
README.md CHANGED
@@ -69,6 +69,12 @@ python infer.py CLDM # SAII-CLDM
69
  `infer.py` uses the bundled Overthrust sample and writes outputs under
70
  `outputs/infer_LDDPM/` or `outputs/infer_CLDM/`.
71
 
 
 
 
 
 
 
72
  ## Overthrust Results
73
 
74
  Impedance-domain metrics on the bundled Overthrust setting:
 
69
  `infer.py` uses the bundled Overthrust sample and writes outputs under
70
  `outputs/infer_LDDPM/` or `outputs/infer_CLDM/`.
71
 
72
+ Add `--eval` to run the full bundled Overthrust evaluation:
73
+
74
+ ```bash
75
+ python infer.py CLDM --eval
76
+ ```
77
+
78
  ## Overthrust Results
79
 
80
  Impedance-domain metrics on the bundled Overthrust setting:
infer.py CHANGED
@@ -1,5 +1,6 @@
1
  from __future__ import annotations
2
 
 
3
  import sys
4
  from pathlib import Path
5
 
@@ -16,11 +17,15 @@ from codes.pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline
16
  from codes.util import OverthrustForwardOperator, ricker_wavelet
17
 
18
 
19
- METHOD = sys.argv[1].upper() if len(sys.argv) > 1 else "LDDPM"
20
- OUT_DIR = REPO_ROOT / "outputs" / f"infer_{METHOD}"
21
  PATCH_INDEX = 0
22
  MODEL_DIR = REPO_ROOT
23
- RUN_EVAL = True
 
 
 
 
 
 
24
 
25
  def save_comparison(dipin, record, target, prediction, output_path):
26
  fig, axes = plt.subplots(1, 4, figsize=(16, 4))
@@ -40,18 +45,19 @@ def save_comparison(dipin, record, target, prediction, output_path):
40
 
41
 
42
  if __name__ == "__main__":
43
- if METHOD not in {"LDDPM", "CLDM"}:
44
- raise ValueError("METHOD must be LDDPM or CLDM. Example: python infer.py CLDM")
 
45
 
46
- OUT_DIR.mkdir(parents=True, exist_ok=True)
47
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
  print(f"Using device: {device}")
49
- print(f"Method: {METHOD}")
50
 
51
  dataset = OverthrustTrueimpDataset(
52
  patch_indices=[PATCH_INDEX],
53
  data_dir=REPO_ROOT / "data",
54
- cache_dir=OUT_DIR / "cache",
55
  )
56
  sample = dataset[0]
57
  dipin = sample["dipin"].unsqueeze(0).to(device)
@@ -59,7 +65,7 @@ if __name__ == "__main__":
59
  image = sample["image"].unsqueeze(0).to(device)
60
  seed = int(sample["seed"])
61
 
62
- if METHOD == "LDDPM":
63
  num_inference_steps = 1000
64
  extra_kwargs = {}
65
  pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
@@ -109,15 +115,15 @@ if __name__ == "__main__":
109
  dipin_np = dipin[0, 0].detach().cpu().numpy()
110
  record_np = record[0, 0].detach().cpu().numpy()
111
 
112
- np.save(OUT_DIR / "prediction.npy", prediction)
113
- np.save(OUT_DIR / "target.npy", target)
114
- save_comparison(dipin_np, record_np, target, prediction, OUT_DIR / "comparison.png")
115
 
116
- print(f"Saved: {OUT_DIR / 'prediction.npy'}")
117
- print(f"Saved: {OUT_DIR / 'target.npy'}")
118
- print(f"Saved: {OUT_DIR / 'comparison.png'}")
119
 
120
- if RUN_EVAL:
121
  from codes.eval_overthrust import evaluate_overthrust
122
 
123
- evaluate_overthrust(pipe, method=METHOD, output_dir=OUT_DIR / "eval")
 
1
  from __future__ import annotations
2
 
3
+ import argparse
4
  import sys
5
  from pathlib import Path
6
 
 
17
  from codes.util import OverthrustForwardOperator, ricker_wavelet
18
 
19
 
 
 
20
  PATCH_INDEX = 0
21
  MODEL_DIR = REPO_ROOT
22
+
23
+
24
+ def parse_args() -> argparse.Namespace:
25
+ parser = argparse.ArgumentParser(description="Run SAII-LDDPM/CLDM inference.")
26
+ parser.add_argument("method", nargs="?", choices=["LDDPM", "CLDM"], default="LDDPM")
27
+ parser.add_argument("--eval", action="store_true", help="Run full Overthrust evaluation after single-sample inference.")
28
+ return parser.parse_args()
29
 
30
  def save_comparison(dipin, record, target, prediction, output_path):
31
  fig, axes = plt.subplots(1, 4, figsize=(16, 4))
 
45
 
46
 
47
  if __name__ == "__main__":
48
+ args = parse_args()
49
+ method = args.method.upper()
50
+ out_dir = REPO_ROOT / "outputs" / f"infer_{method}"
51
 
52
+ out_dir.mkdir(parents=True, exist_ok=True)
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  print(f"Using device: {device}")
55
+ print(f"Method: {method}")
56
 
57
  dataset = OverthrustTrueimpDataset(
58
  patch_indices=[PATCH_INDEX],
59
  data_dir=REPO_ROOT / "data",
60
+ cache_dir=out_dir / "cache",
61
  )
62
  sample = dataset[0]
63
  dipin = sample["dipin"].unsqueeze(0).to(device)
 
65
  image = sample["image"].unsqueeze(0).to(device)
66
  seed = int(sample["seed"])
67
 
68
+ if method == "LDDPM":
69
  num_inference_steps = 1000
70
  extra_kwargs = {}
71
  pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
 
115
  dipin_np = dipin[0, 0].detach().cpu().numpy()
116
  record_np = record[0, 0].detach().cpu().numpy()
117
 
118
+ np.save(out_dir / "prediction.npy", prediction)
119
+ np.save(out_dir / "target.npy", target)
120
+ save_comparison(dipin_np, record_np, target, prediction, out_dir / "comparison.png")
121
 
122
+ print(f"Saved: {out_dir / 'prediction.npy'}")
123
+ print(f"Saved: {out_dir / 'target.npy'}")
124
+ print(f"Saved: {out_dir / 'comparison.png'}")
125
 
126
+ if args.eval:
127
  from codes.eval_overthrust import evaluate_overthrust
128
 
129
+ evaluate_overthrust(pipe, method=method, output_dir=out_dir / "eval")