"""
Example: how to use the modified ``eval_seg`` function to obtain both
segmentation results and decoded text output at the same time.
"""

import torch
from transformers import AutoTokenizer
|
def example_eval_seg_with_text_output(model, tokenizer, input_data):
    """Run PSALM's ``eval_seg`` once and decode any generated text.

    Args:
        model: PSALM model instance exposing an ``eval_seg`` method.
        tokenizer: tokenizer matching the model, used to decode token ids.
        input_data: dict with required keys ``input_ids`` (torch.LongTensor),
            ``images`` (torch.FloatTensor) and ``seg_info``; optional keys
            (``attention_mask``, ``cls_indices``, ...) default to None when
            absent.

    Returns:
        dict with keys ``segmentation_results`` (as returned by ``eval_seg``),
        ``decoded_text`` (str, or None when no tokens were generated) and
        ``raw_token_ids`` (the undecoded token ids, possibly None).
    """
    # Optional inputs are forwarded as-is; dict.get yields None for missing keys.
    optional_keys = (
        'attention_mask',
        'class_name_embedding_indices',
        'cls_indices',
        'token_refer_id',
        'refer_embedding_indices',
        'is_thing_list',
    )
    eval_out = model.eval_seg(
        input_ids=input_data['input_ids'],
        images=input_data['images'],
        seg_info=input_data['seg_info'],
        # Text-generation controls.
        generate_text=True,
        max_new_tokens=512,
        temperature=0.2,
        do_sample=True,
        **{key: input_data.get(key) for key in optional_keys},
    )

    seg_results = eval_out['segmentation_results']
    token_ids = eval_out['output_token_ids']

    # Decode only when the model actually produced tokens.
    if token_ids is None:
        text = None
    else:
        text = tokenizer.decode(
            token_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        print(f"生成的文本: {text}")

    # Pretty-print a per-image summary of whatever result kinds are present.
    for idx, per_image in enumerate(seg_results):
        print(f"图片 {idx} 的分割结果:")
        if 'instances' in per_image:
            inst = per_image['instances']
            print(f" - 检测到 {len(inst.pred_masks)} 个实例")
            print(f" - pred_masks shape: {inst.pred_masks.shape}")
            print(f" - scores: {inst.scores}")
            if hasattr(inst, 'pred_boxes'):
                print(f" - pred_boxes: {inst.pred_boxes}")
        if 'sem_seg' in per_image:
            print(f" - 语义分割结果 shape: {per_image['sem_seg'].shape}")
        if 'panoptic_seg' in per_image:
            print(f" - 全景分割结果")

    return {
        'segmentation_results': seg_results,
        'decoded_text': text,
        'raw_token_ids': token_ids,
    }
|
|
def example_usage():
    """Print setup instructions and build a template ``input_data`` dict.

    Purely illustrative: the dict shows the expected structure of the input
    for :func:`example_eval_seg_with_text_output`; its values are dummy
    placeholders, not meaningful data.
    """
    # BUGFIX: the original used torch.tensor([[1, 2, 3, ...]]) — the literal
    # Ellipsis object cannot be converted by torch.tensor and raised at
    # runtime, crashing the script when executed via __main__. Use a valid
    # placeholder tensor instead.
    input_data = {
        'input_ids': torch.tensor([[1, 2, 3]]),          # placeholder token ids
        'images': torch.randn(1, 3, 224, 224),           # dummy image batch
        'seg_info': [{'instances': None}],               # fill with real seg info
    }
    # Keep a reference so linters don't flag the demo dict as unused.
    assert isinstance(input_data, dict)

    print("示例代码准备完毕!")
    print("使用时请确保:")
    print("1. 已正确加载PSALM模型")
    print("2. 已正确加载对应的tokenizer")
    print("3. 准备了正确格式的input_data")
|
|
# Script entry point: run the demo when executed directly.
if __name__ == "__main__":
    example_usage()
|
|