nocapdev commited on
Commit
9d5974d
Β·
verified Β·
1 Parent(s): 3137bdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +392 -390
app.py CHANGED
@@ -1,391 +1,393 @@
1
- import os
2
- from os.path import join as pjoin
3
- import gradio as gr
4
- import torch
5
- import torch.nn.functional as F
6
- import numpy as np
7
- from torch.distributions.categorical import Categorical
8
-
9
- from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
10
- from models.vq.model import RVQVAE, LengthEstimator
11
- from utils.get_opt import get_opt
12
- from utils.fixseed import fixseed
13
- from visualization.joints2bvh import Joint2BVHConvertor
14
- from utils.motion_process import recover_from_ric
15
- from utils.plot_script import plot_3d_motion
16
- from utils.paramUtil import t2m_kinematic_chain
17
-
18
- clip_version = 'ViT-B/32'
19
-
20
- class MotionGenerator:
21
- def __init__(self, checkpoints_dir, dataset_name, model_name, res_name, vq_name, device='cuda'):
22
- self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
23
- self.dataset_name = dataset_name
24
- self.dim_pose = 251 if dataset_name == 'kit' else 263
25
- self.nb_joints = 21 if dataset_name == 'kit' else 22
26
-
27
- # Load models
28
- print("Loading models...")
29
- self.vq_model, self.vq_opt = self._load_vq_model(checkpoints_dir, dataset_name, vq_name)
30
- self.t2m_transformer = self._load_trans_model(checkpoints_dir, dataset_name, model_name)
31
- self.res_model = self._load_res_model(checkpoints_dir, dataset_name, res_name, self.vq_opt)
32
- self.length_estimator = self._load_len_estimator(checkpoints_dir, dataset_name)
33
-
34
- # Set to eval mode
35
- self.vq_model.eval()
36
- self.t2m_transformer.eval()
37
- self.res_model.eval()
38
- self.length_estimator.eval()
39
-
40
- # Load normalization stats
41
- meta_dir = pjoin(checkpoints_dir, dataset_name, vq_name, 'meta')
42
- self.mean = np.load(pjoin(meta_dir, 'mean.npy'))
43
- self.std = np.load(pjoin(meta_dir, 'std.npy'))
44
-
45
- self.kinematic_chain = t2m_kinematic_chain
46
- self.converter = Joint2BVHConvertor()
47
-
48
- print("Models loaded successfully!")
49
-
50
- def _load_vq_model(self, checkpoints_dir, dataset_name, vq_name):
51
- vq_opt_path = pjoin(checkpoints_dir, dataset_name, vq_name, 'opt.txt')
52
- vq_opt = get_opt(vq_opt_path, device=self.device)
53
- vq_opt.dim_pose = self.dim_pose
54
-
55
- vq_model = RVQVAE(vq_opt,
56
- vq_opt.dim_pose,
57
- vq_opt.nb_code,
58
- vq_opt.code_dim,
59
- vq_opt.output_emb_width,
60
- vq_opt.down_t,
61
- vq_opt.stride_t,
62
- vq_opt.width,
63
- vq_opt.depth,
64
- vq_opt.dilation_growth_rate,
65
- vq_opt.vq_act,
66
- vq_opt.vq_norm)
67
-
68
- ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, vq_name, 'model', 'net_best_fid.tar'),
69
- map_location=self.device)
70
- model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
71
- vq_model.load_state_dict(ckpt[model_key])
72
- vq_model.to(self.device)
73
-
74
- return vq_model, vq_opt
75
-
76
- def _load_trans_model(self, checkpoints_dir, dataset_name, model_name):
77
- model_opt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'opt.txt')
78
- model_opt = get_opt(model_opt_path, device=self.device)
79
-
80
- model_opt.num_tokens = self.vq_opt.nb_code
81
- model_opt.num_quantizers = self.vq_opt.num_quantizers
82
- model_opt.code_dim = self.vq_opt.code_dim
83
-
84
- # Set default values for missing attributes
85
- if not hasattr(model_opt, 'latent_dim'):
86
- model_opt.latent_dim = 384
87
- if not hasattr(model_opt, 'ff_size'):
88
- model_opt.ff_size = 1024
89
- if not hasattr(model_opt, 'n_layers'):
90
- model_opt.n_layers = 8
91
- if not hasattr(model_opt, 'n_heads'):
92
- model_opt.n_heads = 6
93
- if not hasattr(model_opt, 'dropout'):
94
- model_opt.dropout = 0.1
95
- if not hasattr(model_opt, 'cond_drop_prob'):
96
- model_opt.cond_drop_prob = 0.1
97
-
98
- t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
99
- cond_mode='text',
100
- latent_dim=model_opt.latent_dim,
101
- ff_size=model_opt.ff_size,
102
- num_layers=model_opt.n_layers,
103
- num_heads=model_opt.n_heads,
104
- dropout=model_opt.dropout,
105
- clip_dim=512,
106
- cond_drop_prob=model_opt.cond_drop_prob,
107
- clip_version=clip_version,
108
- opt=model_opt)
109
-
110
- ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar'),
111
- map_location=self.device)
112
- model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
113
- t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
114
- t2m_transformer.to(self.device)
115
-
116
- return t2m_transformer
117
-
118
- def _load_res_model(self, checkpoints_dir, dataset_name, res_name, vq_opt):
119
- res_opt_path = pjoin(checkpoints_dir, dataset_name, res_name, 'opt.txt')
120
- res_opt = get_opt(res_opt_path, device=self.device)
121
-
122
- # The res_name appears to be the same as vq_name, so res_opt is actually vq_opt
123
- # We need to use proper model architecture parameters
124
- res_opt.num_quantizers = vq_opt.num_quantizers
125
- res_opt.num_tokens = vq_opt.nb_code
126
-
127
- # Set architecture parameters for ResidualTransformer
128
- # These should match the main transformer architecture
129
- res_opt.latent_dim = 384 # Match with main transformer
130
- res_opt.ff_size = 1024
131
- res_opt.n_layers = 9 # Typically slightly more layers for residual
132
- res_opt.n_heads = 6
133
- res_opt.dropout = 0.1
134
- res_opt.cond_drop_prob = 0.1
135
- res_opt.share_weight = False
136
-
137
- print(f"ResidualTransformer config - latent_dim: {res_opt.latent_dim}, ff_size: {res_opt.ff_size}, nlayers: {res_opt.n_layers}, nheads: {res_opt.n_heads}, dropout: {res_opt.dropout}")
138
-
139
- res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
140
- cond_mode='text',
141
- latent_dim=res_opt.latent_dim,
142
- ff_size=res_opt.ff_size,
143
- num_layers=res_opt.n_layers,
144
- num_heads=res_opt.n_heads,
145
- dropout=res_opt.dropout,
146
- clip_dim=512,
147
- shared_codebook=vq_opt.shared_codebook,
148
- cond_drop_prob=res_opt.cond_drop_prob,
149
- share_weight=res_opt.share_weight,
150
- clip_version=clip_version,
151
- opt=res_opt)
152
-
153
- ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, res_name, 'model', 'net_best_fid.tar'),
154
- map_location=self.device)
155
-
156
- # Debug: check available keys
157
- print(f"Available checkpoint keys: {ckpt.keys()}")
158
-
159
- # Try different possible keys for the model state dict
160
- model_key = None
161
- for key in ['res_transformer', 'trans', 'net', 'model', 'state_dict']:
162
- if key in ckpt:
163
- model_key = key
164
- break
165
-
166
- if model_key:
167
- print(f"Loading ResidualTransformer from key: {model_key}")
168
- res_transformer.load_state_dict(ckpt[model_key], strict=False)
169
- else:
170
- print("Warning: Could not find model weights in checkpoint. Available keys:", list(ckpt.keys()))
171
- # If this is actually a VQ model checkpoint, we might need to skip loading or handle differently
172
- if 'vq_model' in ckpt or 'net' in ckpt:
173
- print("This appears to be a VQ model checkpoint, not a ResidualTransformer checkpoint.")
174
- print("Skipping weight loading - using randomly initialized ResidualTransformer.")
175
-
176
- res_transformer.to(self.device)
177
-
178
- return res_transformer
179
-
180
- def _load_len_estimator(self, checkpoints_dir, dataset_name):
181
- model = LengthEstimator(512, 50)
182
- ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'finest.tar'),
183
- map_location=self.device)
184
- model.load_state_dict(ckpt['estimator'])
185
- model.to(self.device)
186
- return model
187
-
188
- def inv_transform(self, data):
189
- return data * self.std + self.mean
190
-
191
- @torch.no_grad()
192
- def generate(self, text_prompt, motion_length=0, time_steps=18, cond_scale=4,
193
- temperature=1, topkr=0.9, gumbel_sample=True, seed=42):
194
- """
195
- Generate motion from text prompt
196
-
197
- Args:
198
- text_prompt: Text description of the motion
199
- motion_length: Desired motion length (0 for auto-estimation)
200
- time_steps: Number of denoising steps
201
- cond_scale: Classifier-free guidance scale
202
- temperature: Sampling temperature
203
- topkr: Top-k filtering threshold
204
- gumbel_sample: Whether to use Gumbel sampling
205
- seed: Random seed
206
- """
207
- fixseed(seed)
208
-
209
- # Convert motion_length to int if needed
210
- if isinstance(motion_length, float):
211
- motion_length = int(motion_length)
212
-
213
- # Estimate length if not provided
214
- if motion_length == 0:
215
- text_embedding = self.t2m_transformer.encode_text([text_prompt])
216
- pred_dis = self.length_estimator(text_embedding)
217
- probs = F.softmax(pred_dis, dim=-1)
218
- token_lens = Categorical(probs).sample()
219
- else:
220
- token_lens = torch.LongTensor([motion_length // 4]).to(self.device)
221
-
222
- m_length = token_lens * 4
223
-
224
- # Generate motion tokens
225
- mids = self.t2m_transformer.generate([text_prompt], token_lens,
226
- timesteps=int(time_steps),
227
- cond_scale=float(cond_scale),
228
- temperature=float(temperature),
229
- topk_filter_thres=float(topkr),
230
- gsample=gumbel_sample)
231
-
232
- # Refine with residual transformer
233
- mids = self.res_model.generate(mids, [text_prompt], token_lens,
234
- temperature=1, cond_scale=5)
235
-
236
- # Decode to motion
237
- pred_motions = self.vq_model.forward_decoder(mids)
238
- pred_motions = pred_motions.detach().cpu().numpy()
239
-
240
- # Denormalize
241
- data = self.inv_transform(pred_motions)
242
- joint_data = data[0, :m_length[0]]
243
-
244
- # Recover 3D joints
245
- joint = recover_from_ric(torch.from_numpy(joint_data).float(), self.nb_joints).numpy()
246
-
247
- return joint, int(m_length[0].item())
248
-
249
-
250
- def create_gradio_interface(generator, output_dir='./gradio_outputs'):
251
- os.makedirs(output_dir, exist_ok=True)
252
-
253
- def generate_motion(text_prompt):
254
- try:
255
- # Use default parameters for simplicity
256
- motion_length = 0 # Auto-estimate
257
- time_steps = 18
258
- cond_scale = 4.0
259
- temperature = 1.0
260
- topkr = 0.9
261
- use_gumbel = True
262
- seed = 42
263
- use_ik = True
264
-
265
- # Generate motion
266
- joint, actual_length = generator.generate(
267
- text_prompt,
268
- motion_length,
269
- time_steps,
270
- cond_scale,
271
- temperature,
272
- topkr,
273
- use_gumbel,
274
- seed
275
- )
276
-
277
- # Save BVH and video
278
- timestamp = str(np.random.randint(100000))
279
- video_path = pjoin(output_dir, f'motion_{timestamp}.mp4')
280
-
281
- # Convert to BVH with foot IK
282
- _, joint_processed = generator.converter.convert(
283
- joint, filename=None, iterations=100, foot_ik=True
284
- )
285
-
286
- # Create video
287
- plot_3d_motion(video_path, generator.kinematic_chain, joint_processed,
288
- title=text_prompt, fps=20)
289
-
290
- return video_path
291
-
292
- except Exception as e:
293
- import traceback
294
- error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
295
- print(error_msg)
296
- return None
297
-
298
- # Create Gradio interface with Blocks for custom layout
299
- with gr.Blocks(theme=gr.themes.Base(
300
- primary_hue="blue",
301
- secondary_hue="gray",
302
- ).set(
303
- body_background_fill="*neutral_950",
304
- body_background_fill_dark="*neutral_950",
305
- background_fill_primary="*neutral_900",
306
- background_fill_primary_dark="*neutral_900",
307
- background_fill_secondary="*neutral_800",
308
- background_fill_secondary_dark="*neutral_800",
309
- block_background_fill="*neutral_900",
310
- block_background_fill_dark="*neutral_900",
311
- input_background_fill="*neutral_800",
312
- input_background_fill_dark="*neutral_800",
313
- button_primary_background_fill="*primary_600",
314
- button_primary_background_fill_dark="*primary_600",
315
- button_primary_text_color="white",
316
- button_primary_text_color_dark="white",
317
- block_label_text_color="*neutral_200",
318
- block_label_text_color_dark="*neutral_200",
319
- body_text_color="*neutral_200",
320
- body_text_color_dark="*neutral_200",
321
- input_placeholder_color="*neutral_500",
322
- input_placeholder_color_dark="*neutral_500",
323
- ),
324
- css="""
325
- footer {display: none !important;}
326
- .video-fixed-height {
327
- height: 600px !important;
328
- }
329
- .video-fixed-height video {
330
- max-height: 600px !important;
331
- object-fit: contain !important;
332
- }
333
- """) as demo:
334
-
335
- gr.Markdown("# 🎭 Text-to-Motion Generator")
336
- gr.Markdown("Generate 3D human motion animations from text descriptions")
337
-
338
- with gr.Row():
339
- with gr.Column():
340
- text_input = gr.Textbox(
341
- label="Describe the motion you want to generate",
342
- placeholder="e.g., 'a person walks forward and waves'",
343
- lines=3
344
- )
345
- submit_btn = gr.Button("Generate Motion", variant="primary")
346
-
347
- gr.Examples(
348
- examples=[
349
- ["a person walks forward"],
350
- ["a person jumps in place"],
351
- ["someone performs a dance move"],
352
- ["a person sits down on a chair"],
353
- ["a person runs and then stops"],
354
- ],
355
- inputs=text_input,
356
- label="Try these examples"
357
- )
358
-
359
- with gr.Column():
360
- video_output = gr.Video(label="Generated Motion", elem_classes="video-fixed-height")
361
-
362
- submit_btn.click(
363
- fn=generate_motion,
364
- inputs=text_input,
365
- outputs=video_output
366
- )
367
-
368
- return demo
369
-
370
-
371
- if __name__ == '__main__':
372
- # Configuration
373
- CHECKPOINTS_DIR = './checkpoints'
374
- DATASET_NAME = 't2m' # or 'kit'
375
- MODEL_NAME = 't2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns'
376
- RES_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
377
- VQ_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
378
-
379
- # Initialize generator
380
- generator = MotionGenerator(
381
- checkpoints_dir=CHECKPOINTS_DIR,
382
- dataset_name=DATASET_NAME,
383
- model_name=MODEL_NAME,
384
- res_name=RES_NAME,
385
- vq_name=VQ_NAME,
386
- device='cuda'
387
- )
388
-
389
- # Create and launch Gradio interface
390
- demo = create_gradio_interface(generator)
 
 
391
  demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
 
1
+ import os
2
+ os.environ['MKL_THREADING_LAYER'] = 'GNU'
3
+ import torch
4
+ from os.path import join as pjoin
5
+ import gradio as gr
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from torch.distributions.categorical import Categorical
10
+
11
+ from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
12
+ from models.vq.model import RVQVAE, LengthEstimator
13
+ from utils.get_opt import get_opt
14
+ from utils.fixseed import fixseed
15
+ from visualization.joints2bvh import Joint2BVHConvertor
16
+ from utils.motion_process import recover_from_ric
17
+ from utils.plot_script import plot_3d_motion
18
+ from utils.paramUtil import t2m_kinematic_chain
19
+
20
+ clip_version = 'ViT-B/32'
21
+
22
+ class MotionGenerator:
23
+ def __init__(self, checkpoints_dir, dataset_name, model_name, res_name, vq_name, device='cuda'):
24
+ self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
25
+ self.dataset_name = dataset_name
26
+ self.dim_pose = 251 if dataset_name == 'kit' else 263
27
+ self.nb_joints = 21 if dataset_name == 'kit' else 22
28
+
29
+ # Load models
30
+ print("Loading models...")
31
+ self.vq_model, self.vq_opt = self._load_vq_model(checkpoints_dir, dataset_name, vq_name)
32
+ self.t2m_transformer = self._load_trans_model(checkpoints_dir, dataset_name, model_name)
33
+ self.res_model = self._load_res_model(checkpoints_dir, dataset_name, res_name, self.vq_opt)
34
+ self.length_estimator = self._load_len_estimator(checkpoints_dir, dataset_name)
35
+
36
+ # Set to eval mode
37
+ self.vq_model.eval()
38
+ self.t2m_transformer.eval()
39
+ self.res_model.eval()
40
+ self.length_estimator.eval()
41
+
42
+ # Load normalization stats
43
+ meta_dir = pjoin(checkpoints_dir, dataset_name, vq_name, 'meta')
44
+ self.mean = np.load(pjoin(meta_dir, 'mean.npy'))
45
+ self.std = np.load(pjoin(meta_dir, 'std.npy'))
46
+
47
+ self.kinematic_chain = t2m_kinematic_chain
48
+ self.converter = Joint2BVHConvertor()
49
+
50
+ print("Models loaded successfully!")
51
+
52
+ def _load_vq_model(self, checkpoints_dir, dataset_name, vq_name):
53
+ vq_opt_path = pjoin(checkpoints_dir, dataset_name, vq_name, 'opt.txt')
54
+ vq_opt = get_opt(vq_opt_path, device=self.device)
55
+ vq_opt.dim_pose = self.dim_pose
56
+
57
+ vq_model = RVQVAE(vq_opt,
58
+ vq_opt.dim_pose,
59
+ vq_opt.nb_code,
60
+ vq_opt.code_dim,
61
+ vq_opt.output_emb_width,
62
+ vq_opt.down_t,
63
+ vq_opt.stride_t,
64
+ vq_opt.width,
65
+ vq_opt.depth,
66
+ vq_opt.dilation_growth_rate,
67
+ vq_opt.vq_act,
68
+ vq_opt.vq_norm)
69
+
70
+ ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, vq_name, 'model', 'net_best_fid.tar'),
71
+ map_location=self.device)
72
+ model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
73
+ vq_model.load_state_dict(ckpt[model_key])
74
+ vq_model.to(self.device)
75
+
76
+ return vq_model, vq_opt
77
+
78
+ def _load_trans_model(self, checkpoints_dir, dataset_name, model_name):
79
+ model_opt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'opt.txt')
80
+ model_opt = get_opt(model_opt_path, device=self.device)
81
+
82
+ model_opt.num_tokens = self.vq_opt.nb_code
83
+ model_opt.num_quantizers = self.vq_opt.num_quantizers
84
+ model_opt.code_dim = self.vq_opt.code_dim
85
+
86
+ # Set default values for missing attributes
87
+ if not hasattr(model_opt, 'latent_dim'):
88
+ model_opt.latent_dim = 384
89
+ if not hasattr(model_opt, 'ff_size'):
90
+ model_opt.ff_size = 1024
91
+ if not hasattr(model_opt, 'n_layers'):
92
+ model_opt.n_layers = 8
93
+ if not hasattr(model_opt, 'n_heads'):
94
+ model_opt.n_heads = 6
95
+ if not hasattr(model_opt, 'dropout'):
96
+ model_opt.dropout = 0.1
97
+ if not hasattr(model_opt, 'cond_drop_prob'):
98
+ model_opt.cond_drop_prob = 0.1
99
+
100
+ t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
101
+ cond_mode='text',
102
+ latent_dim=model_opt.latent_dim,
103
+ ff_size=model_opt.ff_size,
104
+ num_layers=model_opt.n_layers,
105
+ num_heads=model_opt.n_heads,
106
+ dropout=model_opt.dropout,
107
+ clip_dim=512,
108
+ cond_drop_prob=model_opt.cond_drop_prob,
109
+ clip_version=clip_version,
110
+ opt=model_opt)
111
+
112
+ ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar'),
113
+ map_location=self.device)
114
+ model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
115
+ t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
116
+ t2m_transformer.to(self.device)
117
+
118
+ return t2m_transformer
119
+
120
+ def _load_res_model(self, checkpoints_dir, dataset_name, res_name, vq_opt):
121
+ res_opt_path = pjoin(checkpoints_dir, dataset_name, res_name, 'opt.txt')
122
+ res_opt = get_opt(res_opt_path, device=self.device)
123
+
124
+ # The res_name appears to be the same as vq_name, so res_opt is actually vq_opt
125
+ # We need to use proper model architecture parameters
126
+ res_opt.num_quantizers = vq_opt.num_quantizers
127
+ res_opt.num_tokens = vq_opt.nb_code
128
+
129
+ # Set architecture parameters for ResidualTransformer
130
+ # These should match the main transformer architecture
131
+ res_opt.latent_dim = 384 # Match with main transformer
132
+ res_opt.ff_size = 1024
133
+ res_opt.n_layers = 9 # Typically slightly more layers for residual
134
+ res_opt.n_heads = 6
135
+ res_opt.dropout = 0.1
136
+ res_opt.cond_drop_prob = 0.1
137
+ res_opt.share_weight = False
138
+
139
+ print(f"ResidualTransformer config - latent_dim: {res_opt.latent_dim}, ff_size: {res_opt.ff_size}, nlayers: {res_opt.n_layers}, nheads: {res_opt.n_heads}, dropout: {res_opt.dropout}")
140
+
141
+ res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
142
+ cond_mode='text',
143
+ latent_dim=res_opt.latent_dim,
144
+ ff_size=res_opt.ff_size,
145
+ num_layers=res_opt.n_layers,
146
+ num_heads=res_opt.n_heads,
147
+ dropout=res_opt.dropout,
148
+ clip_dim=512,
149
+ shared_codebook=vq_opt.shared_codebook,
150
+ cond_drop_prob=res_opt.cond_drop_prob,
151
+ share_weight=res_opt.share_weight,
152
+ clip_version=clip_version,
153
+ opt=res_opt)
154
+
155
+ ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, res_name, 'model', 'net_best_fid.tar'),
156
+ map_location=self.device)
157
+
158
+ # Debug: check available keys
159
+ print(f"Available checkpoint keys: {ckpt.keys()}")
160
+
161
+ # Try different possible keys for the model state dict
162
+ model_key = None
163
+ for key in ['res_transformer', 'trans', 'net', 'model', 'state_dict']:
164
+ if key in ckpt:
165
+ model_key = key
166
+ break
167
+
168
+ if model_key:
169
+ print(f"Loading ResidualTransformer from key: {model_key}")
170
+ res_transformer.load_state_dict(ckpt[model_key], strict=False)
171
+ else:
172
+ print("Warning: Could not find model weights in checkpoint. Available keys:", list(ckpt.keys()))
173
+ # If this is actually a VQ model checkpoint, we might need to skip loading or handle differently
174
+ if 'vq_model' in ckpt or 'net' in ckpt:
175
+ print("This appears to be a VQ model checkpoint, not a ResidualTransformer checkpoint.")
176
+ print("Skipping weight loading - using randomly initialized ResidualTransformer.")
177
+
178
+ res_transformer.to(self.device)
179
+
180
+ return res_transformer
181
+
182
+ def _load_len_estimator(self, checkpoints_dir, dataset_name):
183
+ model = LengthEstimator(512, 50)
184
+ ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'finest.tar'),
185
+ map_location=self.device)
186
+ model.load_state_dict(ckpt['estimator'])
187
+ model.to(self.device)
188
+ return model
189
+
190
+ def inv_transform(self, data):
191
+ return data * self.std + self.mean
192
+
193
+ @torch.no_grad()
194
+ def generate(self, text_prompt, motion_length=0, time_steps=18, cond_scale=4,
195
+ temperature=1, topkr=0.9, gumbel_sample=True, seed=42):
196
+ """
197
+ Generate motion from text prompt
198
+
199
+ Args:
200
+ text_prompt: Text description of the motion
201
+ motion_length: Desired motion length (0 for auto-estimation)
202
+ time_steps: Number of denoising steps
203
+ cond_scale: Classifier-free guidance scale
204
+ temperature: Sampling temperature
205
+ topkr: Top-k filtering threshold
206
+ gumbel_sample: Whether to use Gumbel sampling
207
+ seed: Random seed
208
+ """
209
+ fixseed(seed)
210
+
211
+ # Convert motion_length to int if needed
212
+ if isinstance(motion_length, float):
213
+ motion_length = int(motion_length)
214
+
215
+ # Estimate length if not provided
216
+ if motion_length == 0:
217
+ text_embedding = self.t2m_transformer.encode_text([text_prompt])
218
+ pred_dis = self.length_estimator(text_embedding)
219
+ probs = F.softmax(pred_dis, dim=-1)
220
+ token_lens = Categorical(probs).sample()
221
+ else:
222
+ token_lens = torch.LongTensor([motion_length // 4]).to(self.device)
223
+
224
+ m_length = token_lens * 4
225
+
226
+ # Generate motion tokens
227
+ mids = self.t2m_transformer.generate([text_prompt], token_lens,
228
+ timesteps=int(time_steps),
229
+ cond_scale=float(cond_scale),
230
+ temperature=float(temperature),
231
+ topk_filter_thres=float(topkr),
232
+ gsample=gumbel_sample)
233
+
234
+ # Refine with residual transformer
235
+ mids = self.res_model.generate(mids, [text_prompt], token_lens,
236
+ temperature=1, cond_scale=5)
237
+
238
+ # Decode to motion
239
+ pred_motions = self.vq_model.forward_decoder(mids)
240
+ pred_motions = pred_motions.detach().cpu().numpy()
241
+
242
+ # Denormalize
243
+ data = self.inv_transform(pred_motions)
244
+ joint_data = data[0, :m_length[0]]
245
+
246
+ # Recover 3D joints
247
+ joint = recover_from_ric(torch.from_numpy(joint_data).float(), self.nb_joints).numpy()
248
+
249
+ return joint, int(m_length[0].item())
250
+
251
+
252
+ def create_gradio_interface(generator, output_dir='./gradio_outputs'):
253
+ os.makedirs(output_dir, exist_ok=True)
254
+
255
+ def generate_motion(text_prompt):
256
+ try:
257
+ # Use default parameters for simplicity
258
+ motion_length = 0 # Auto-estimate
259
+ time_steps = 18
260
+ cond_scale = 4.0
261
+ temperature = 1.0
262
+ topkr = 0.9
263
+ use_gumbel = True
264
+ seed = 42
265
+ use_ik = True
266
+
267
+ # Generate motion
268
+ joint, actual_length = generator.generate(
269
+ text_prompt,
270
+ motion_length,
271
+ time_steps,
272
+ cond_scale,
273
+ temperature,
274
+ topkr,
275
+ use_gumbel,
276
+ seed
277
+ )
278
+
279
+ # Save BVH and video
280
+ timestamp = str(np.random.randint(100000))
281
+ video_path = pjoin(output_dir, f'motion_{timestamp}.mp4')
282
+
283
+ # Convert to BVH with foot IK
284
+ _, joint_processed = generator.converter.convert(
285
+ joint, filename=None, iterations=100, foot_ik=True
286
+ )
287
+
288
+ # Create video
289
+ plot_3d_motion(video_path, generator.kinematic_chain, joint_processed,
290
+ title=text_prompt, fps=20)
291
+
292
+ return video_path
293
+
294
+ except Exception as e:
295
+ import traceback
296
+ error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
297
+ print(error_msg)
298
+ return None
299
+
300
+ # Create Gradio interface with Blocks for custom layout
301
+ with gr.Blocks(theme=gr.themes.Base(
302
+ primary_hue="blue",
303
+ secondary_hue="gray",
304
+ ).set(
305
+ body_background_fill="*neutral_950",
306
+ body_background_fill_dark="*neutral_950",
307
+ background_fill_primary="*neutral_900",
308
+ background_fill_primary_dark="*neutral_900",
309
+ background_fill_secondary="*neutral_800",
310
+ background_fill_secondary_dark="*neutral_800",
311
+ block_background_fill="*neutral_900",
312
+ block_background_fill_dark="*neutral_900",
313
+ input_background_fill="*neutral_800",
314
+ input_background_fill_dark="*neutral_800",
315
+ button_primary_background_fill="*primary_600",
316
+ button_primary_background_fill_dark="*primary_600",
317
+ button_primary_text_color="white",
318
+ button_primary_text_color_dark="white",
319
+ block_label_text_color="*neutral_200",
320
+ block_label_text_color_dark="*neutral_200",
321
+ body_text_color="*neutral_200",
322
+ body_text_color_dark="*neutral_200",
323
+ input_placeholder_color="*neutral_500",
324
+ input_placeholder_color_dark="*neutral_500",
325
+ ),
326
+ css="""
327
+ footer {display: none !important;}
328
+ .video-fixed-height {
329
+ height: 600px !important;
330
+ }
331
+ .video-fixed-height video {
332
+ max-height: 600px !important;
333
+ object-fit: contain !important;
334
+ }
335
+ """) as demo:
336
+
337
+ gr.Markdown("# 🎭 Text-to-Motion Generator")
338
+ gr.Markdown("Generate 3D human motion animations from text descriptions")
339
+
340
+ with gr.Row():
341
+ with gr.Column():
342
+ text_input = gr.Textbox(
343
+ label="Describe the motion you want to generate",
344
+ placeholder="e.g., 'a person walks forward and waves'",
345
+ lines=3
346
+ )
347
+ submit_btn = gr.Button("Generate Motion", variant="primary")
348
+
349
+ gr.Examples(
350
+ examples=[
351
+ ["a person walks forward"],
352
+ ["a person jumps in place"],
353
+ ["someone performs a dance move"],
354
+ ["a person sits down on a chair"],
355
+ ["a person runs and then stops"],
356
+ ],
357
+ inputs=text_input,
358
+ label="Try these examples"
359
+ )
360
+
361
+ with gr.Column():
362
+ video_output = gr.Video(label="Generated Motion", elem_classes="video-fixed-height")
363
+
364
+ submit_btn.click(
365
+ fn=generate_motion,
366
+ inputs=text_input,
367
+ outputs=video_output
368
+ )
369
+
370
+ return demo
371
+
372
+
373
+ if __name__ == '__main__':
374
+ # Configuration
375
+ CHECKPOINTS_DIR = './checkpoints'
376
+ DATASET_NAME = 't2m' # or 'kit'
377
+ MODEL_NAME = 't2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns'
378
+ RES_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
379
+ VQ_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
380
+
381
+ # Initialize generator
382
+ generator = MotionGenerator(
383
+ checkpoints_dir=CHECKPOINTS_DIR,
384
+ dataset_name=DATASET_NAME,
385
+ model_name=MODEL_NAME,
386
+ res_name=RES_NAME,
387
+ vq_name=VQ_NAME,
388
+ device='cuda'
389
+ )
390
+
391
+ # Create and launch Gradio interface
392
+ demo = create_gradio_interface(generator)
393
  demo.launch(share=True, server_name="0.0.0.0", server_port=7860)