Diffusers
Safetensors
File size: 4,503 Bytes
87a49e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import sys
import os
import torch
from PIL import Image
import gradio as gr
from glob import glob
from contextlib import nullcontext
from pipeline import Lotus2Pipeline
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    FluxTransformer2DModel,
)
from infer import (
    load_lora_and_lcm_weights,
    process_single_image
)
from evaluation.evaluation import evaluation_depth, evaluation_normal


# Populated by load_pipeline(); stays None until the model has been built.
pipeline = None
# Prefer the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Precision used when loading the model (see load_pipeline).
weight_dtype = torch.bfloat16
# Benchmark family to run, selected via the TASK_NAME environment variable.
# eval() understands "depth" (the default) and "normal".
task = os.getenv("TASK_NAME", "depth")

def load_pipeline():
    """Build the Lotus-2 pipeline on top of FLUX.1-dev and store it in the global ``pipeline``.

    Steps:
      1. Load the FLUX flow-matching scheduler with ``num_train_timesteps=10``.
      2. Load the FLUX transformer directly in ``weight_dtype``, freeze it and
         move it to ``device``.
      3. Apply the task-specific weights via ``load_lora_and_lcm_weights``, which
         returns the updated transformer plus a local continuity module.
      4. Assemble ``Lotus2Pipeline`` from the remaining FLUX.1-dev components,
         attach the local continuity module, move everything to ``device`` and
         disable the progress bar.

    Side effects: rebinds the module-level ``pipeline``. Checkpoints are read
    via ``from_pretrained`` from ``black-forest-labs/FLUX.1-dev`` (which may
    trigger a download).
    """
    # Only ``pipeline`` is assigned here; device/weight_dtype/task are read-only.
    global pipeline
    model_id = 'black-forest-labs/FLUX.1-dev'

    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_id, subfolder="scheduler", num_train_timesteps=10
    )
    # Load straight into weight_dtype. Loading in the default float32 and casting
    # afterwards briefly holds a full-precision copy (~2x host memory) and ends
    # up with the same weights once cast.
    transformer = FluxTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    # Inference only: no gradients are needed for the transformer.
    transformer.requires_grad_(False)
    transformer.to(device=device, dtype=weight_dtype)
    transformer, local_continuity_module = load_lora_and_lcm_weights(transformer, None, None, None, task)

    pipeline = Lotus2Pipeline.from_pretrained(
        model_id,
        scheduler=noise_scheduler,
        transformer=transformer,
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    # Attached as a plain attribute rather than passed as a pipeline component.
    # NOTE(review): pipeline.to() below may not move it. Confirm that
    # load_lora_and_lcm_weights already places it on `device`.
    pipeline.local_continuity_module = local_continuity_module
    pipeline = pipeline.to(device)
    # Evaluation calls the pipeline once per image; silence the per-call bars.
    pipeline.set_progress_bar_config(disable=True)

def eval():
    """Benchmark the loaded pipeline on the zero-shot suite for the current ``task``.

    Paths are taken from the environment:
      * ``TEST_DATA_DIR`` (default ``datasets/eval``): root of the test data.
        The task name is appended as a sub-directory.
      * ``OUTPUT_DIR`` (default ``outputs/eval``): destination for evaluation
        output, also under a per-task sub-directory.

    ``task == "depth"`` runs ``evaluation_depth`` on nyuv2, kitti, scannet, eth3d
    and diode with ``least_square_disparity`` alignment. It prints
    ``abs_relative_difference`` and ``delta1_acc`` for each dataset.
    ``task == "normal"`` runs ``evaluation_normal`` on nyuv2, scannet, ibims,
    sintel and oasis. Everything runs under ``torch.no_grad()``.

    Requires ``load_pipeline()`` to have been called first, because it uses the
    global ``pipeline``. Note that this name shadows the builtin ``eval``
    within this module.

    Raises:
        ValueError: if ``task`` is neither ``"depth"`` nor ``"normal"``.
    """
    base_test_data_dir = os.getenv("TEST_DATA_DIR", "datasets/eval")
    output_dir = os.getenv("OUTPUT_DIR", "outputs/eval")

    def gen_fn(rgb_in):
        """Run the pipeline on one image and shape the output for the evaluator.

        depth: the input (assumed to be in [0, 255]) is rescaled to [-1, 1], the
        pipeline is asked for numpy output, and the result is averaged over its
        last axis.
        normal: the input is passed through unchanged, the pipeline is asked for
        a tensor, and that tensor is mapped [0, 1] -> [-1, 1] and given a
        leading dimension of size 1.
        """
        if task not in ("depth", "normal"):
            raise ValueError(f"Invalid task name: {task}")
        is_depth = task == "depth"

        # [0, 255] -> [-1, 1] for depth; normal consumes the raw input.
        rgb_input = rgb_in / 255.0 * 2.0 - 1.0 if is_depth else rgb_in
        result = pipeline(
            rgb_in=rgb_input,
            prompt='',
            num_inference_steps=10,
            output_type="np" if is_depth else "pt",
            process_res=None,
        )
        prediction = result.images[0]

        if is_depth:
            return prediction.mean(axis=-1)
        return (prediction * 2.0 - 1.0).unsqueeze(0)

    with torch.no_grad():
        if task == 'depth':
            test_data_dir = os.path.join(base_test_data_dir, task)
            # (dataset name, dataset config path relative to test_data_dir)
            depth_benchmarks = (
                ("nyuv2", "configs/data_nyu_test.yaml"),
                ("kitti", "configs/data_kitti_eigen_test.yaml"),
                ("scannet", "configs/data_scannet_val.yaml"),
                ("eth3d", "configs/data_eth3d.yaml"),
                ("diode", "configs/data_diode_all.yaml"),
            )
            for name, rel_config in depth_benchmarks:
                tracker = evaluation_depth(
                    os.path.join(output_dir, task, name),
                    os.path.join(test_data_dir, rel_config),
                    test_data_dir,
                    eval_mode="generate_prediction",
                    gen_prediction=gen_fn,
                    pipeline=pipeline,
                    alignment="least_square_disparity",
                    processing_res=None,
                )
                fetch = tracker.result
                print(name, ',', 'abs_relative_difference: ', fetch()['abs_relative_difference'], 'delta1_acc: ', fetch()['delta1_acc'])
        elif task == 'normal':
            evaluation_normal(
                os.path.join(output_dir, task),
                os.path.join(base_test_data_dir, task),
                "evaluation/dataset_normal",
                eval_mode="generate_prediction",
                gen_prediction=gen_fn,
                pipeline=pipeline,
                eval_datasets=[
                    ('nyuv2', 'test'),
                    ('scannet', 'test'),
                    ('ibims', 'ibims'),
                    ('sintel', 'sintel'),
                    ('oasis', 'val'),
                ],
                processing_res=None,
            )
        else:
            raise ValueError(f"Not support predicting {task} yet. ")

        print('==> Evaluation is done. \n==> Results saved to:', output_dir)


def _main():
    """Script entry point: build the pipeline, then run the benchmark for ``task``."""
    load_pipeline()
    eval()


if __name__ == "__main__":
    _main()