qic999's picture
Upload folder using huggingface_hub
74ecfd7 verified
import os
import nibabel as nib
import pandas as pd
import numpy as np
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
# ---- Dataset locations ----
DATA_ROOT = "/home/shuhan/blobdata/CT-RATE/dataset/"
TRAIN_META = "/home/shuhan/blobdata/CT-RATE/metadata/train_metadata.csv"
VALID_META = "/home/shuhan/blobdata/CT-RATE/metadata/validation_metadata.csv"

# ---- Processing parameters ----
# Inclusive (low, high) bounds voxel intensities are clipped to
# (presumably Hounsfield units for these CT volumes).
CLIP_RANGE = (-1_000, 1_000)
# All CPUs but one, never fewer than one; assume 8 if the count is unknown.
NUM_WORKERS = max((os.cpu_count() or 8) - 1, 1)
def build_full_path(basename, data_root=None):
    """Build the full path of a volume from its file name.

    Volumes are nested three directories deep, with directory names derived
    from the underscore-separated parts of the file name, e.g.::

        train_1661_b_1.nii.gz
        -> <data_root>/train_fixed/train_1661/train_1661_b/train_1661_b_1.nii.gz

    Args:
        basename: Volume file name, e.g. ``"train_1661_b_1.nii.gz"``.
        data_root: Dataset root directory. ``None`` (the default) uses the
            module-level ``DATA_ROOT``.

    Returns:
        The joined path string.

    Raises:
        ValueError: If the name (without ``.nii.gz``) has fewer than three
            underscore-separated parts.
    """
    if data_root is None:
        data_root = DATA_ROOT
    suffix = ".nii.gz"
    # Strip only the trailing extension; str.replace would also delete a
    # ".nii.gz" occurring anywhere else in the name.
    stem = basename[: -len(suffix)] if basename.endswith(suffix) else basename
    parts = stem.split("_")  # e.g. ['train', '1661', 'b', '1']
    if len(parts) < 3:
        raise ValueError(f"basename格式异常: {basename}")
    split_dir = f"{parts[0]}_fixed"  # e.g. train_fixed
    case_dir = f"{parts[0]}_{parts[1]}"  # e.g. train_1661
    series_dir = f"{parts[0]}_{parts[1]}_{parts[2]}"  # e.g. train_1661_b
    return os.path.join(data_root, split_dir, case_dir, series_dir, basename)
def _atomic_save(img, path):
    """Save ``img`` to ``path`` without ever leaving a half-written ``path``.

    The image is first written to a hidden sibling file whose name keeps the
    original extension (nibabel chooses the output format, including gzip,
    from the extension). That file is then moved over ``path`` with
    ``os.replace``. If anything fails, the temporary file is removed and the
    error propagates.
    """
    folder, name = os.path.split(path)
    tmp_path = os.path.join(folder, f".tmp_{os.getpid()}_{name}")
    try:
        nib.save(img, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists if save/replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def process_one(row):
    """Clip one volume's intensities to CLIP_RANGE, overwriting its file.

    Args:
        row: A metadata row supporting ``row["VolumeName"]`` (e.g. a pandas
            Series from ``DataFrame.iterrows``), whose value is the volume's
            file name.

    Returns:
        A one-line status string prefixed with one of:

        - ``[WARN]``: the file does not exist.
        - ``[SKIP]``: values are already within range.
        - ``[OK]``: clipped and saved.
        - ``[ERROR]``: the name could not be turned into a path, or
          load/clip/save failed.
    """
    basename = row["VolumeName"]
    # Guard path construction too: a malformed name (ValueError) or a
    # non-string value such as NaN from the CSV (AttributeError) must not
    # escape the worker, where fut.result() would abort the whole batch.
    try:
        nii_path = build_full_path(basename)
    except (ValueError, AttributeError) as e:
        return f"[ERROR] {basename}: {e}"
    if not os.path.exists(nii_path):
        return f"[WARN] 文件不存在: {nii_path}"
    low, high = CLIP_RANGE
    try:
        nii = nib.load(nii_path)
        # Ask for float32 directly. The float64 default followed by astype
        # keeps both arrays alive, since get_fdata caches its result on the image.
        arr = nii.get_fdata(dtype=np.float32)
        vmin, vmax = float(arr.min()), float(arr.max())
        # Already inside the range: leave the file untouched.
        if vmin >= low and vmax <= high:
            return f"[SKIP] {basename} 已在范围内 ({vmin:.1f},{vmax:.1f})"
        hu = np.clip(arr, low, high)
        # Reuse the source affine and header for the clipped image.
        new_img = nib.Nifti1Image(hu, nii.affine, nii.header)
        # Overwrite via temp file + rename so a crash cannot truncate the original.
        _atomic_save(new_img, nii_path)
        return f"[OK] {basename} -> ({vmin:.1f},{vmax:.1f}) clipped"
    except Exception as e:
        # Worker boundary: report instead of raising so one bad volume does
        # not stop the batch.
        return f"[ERROR] {nii_path}: {e}"
def process_meta(meta_csv, max_workers=None):
    """Run process_one over every row of a metadata CSV in worker processes.

    Args:
        meta_csv: Path to a metadata CSV. Each row is passed to process_one,
            which reads its ``VolumeName`` column.
        max_workers: Number of worker processes. ``None`` (the default) uses
            NUM_WORKERS; pass a smaller value (e.g. 1) to limit memory use or
            to debug one volume at a time.

    Returns:
        A list with one status string per row, in completion order (not CSV
        order).
    """
    if max_workers is None:
        max_workers = NUM_WORKERS
    df = pd.read_csv(meta_csv)
    results = []
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        # Map each future to its volume name so failures can be attributed.
        futures = {
            ex.submit(process_one, row): row.get("VolumeName")
            for _, row in df.iterrows()
        }
        desc = f"Processing {os.path.basename(meta_csv)}"
        for fut in tqdm(as_completed(futures), total=len(futures), desc=desc):
            try:
                results.append(fut.result())
            except Exception as e:
                # An exception escaping a worker, or a crashed worker
                # (BrokenProcessPool), is recorded rather than aborting the
                # loop and discarding every result gathered so far.
                results.append(f"[ERROR] {futures[fut]}: {e!r}")
    return results
if __name__ == "__main__":
    from collections import Counter

    # Only the validation split is processed; add TRAIN_META to this list to
    # clip the training split as well.
    meta_files = [VALID_META]
    results = []
    for meta_csv in meta_files:
        results.extend(process_meta(meta_csv))

    # Report the outcome so missing files and failures are visible: one line
    # of counts per status tag, then every line that needs attention.
    status_counts = Counter(line.split(" ", 1)[0] for line in results)
    print(", ".join(f"{tag} x{n}" for tag, n in sorted(status_counts.items())))
    for line in results:
        if line.startswith(("[WARN]", "[ERROR]")):
            print(line)