| import argparse |
| import copy |
| import json |
| import os.path as osp |
|
|
| import mmengine |
| from mmengine.config import Config, ConfigDict |
| from mmengine.utils import mkdir_or_exist |
| from tqdm import tqdm |
|
|
| from opencompass.registry import TEXT_POSTPROCESSORS |
| from opencompass.utils import build_dataset_from_cfg, get_infer_output_path |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments of the case analyzer.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes ``argparse`` read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``config`` (str), ``force`` (bool) and
        ``work_dir`` (str or None).
    """
    parser = argparse.ArgumentParser(description='Run case analyzer')
    parser.add_argument('config', help='Config file path')
    parser.add_argument(
        '-f',
        '--force',
        help='Force to run the task even if the results already exist',
        action='store_true',
        default=False)
    parser.add_argument(
        '-w',
        '--work-dir',
        help='Work path. Predictions are read from <work_dir>/predictions '
        'and the case analysis results are saved to '
        '<work_dir>/case_analysis. If not specified, the work_dir in the '
        'config is used, falling back to ./outputs/default.',
        default=None,
        type=str)
    return parser.parse_args(argv)
|
|
|
|
class BadcaseShower:
    """Dump all cases and bad cases of one model on one dataset.

    Predictions are read from ``<work_dir>/predictions``. They come either
    from a single file or from shards ``<name>_0<ext>``, ``<name>_1<ext>``,
    and so on. They are compared with the dataset references and written as
    JSON lists to ``<work_dir>/case_analysis/all`` and
    ``<work_dir>/case_analysis/bad``.

    Args:
        cfg: Mapping with keys ``'model'``, ``'dataset'`` and ``'work_dir'``.
            ``cfg['dataset']['eval_cfg']`` supplies ``ds_column`` (the
            reference column). It may also supply ``ds_split``,
            ``dataset_postprocessor`` and ``pred_postprocessor``.
    """

    def __init__(self, cfg: ConfigDict) -> None:
        self.cfg = cfg
        # Deep-copied, presumably so downstream calls cannot mutate the
        # caller's config.
        self.model_cfg = copy.deepcopy(self.cfg['model'])
        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])
        self.work_dir = self.cfg.get('work_dir')

        self.eval_cfg = self.dataset_cfg.get('eval_cfg')
        self.ds_split = self.eval_cfg.get('ds_split', None)
        self.ds_column = self.eval_cfg.get('ds_column')

    @staticmethod
    def _build_postprocessor(proc_cfg):
        """Turn a ``dict(type=..., **kwargs)`` postprocessor config into a callable.

        ``type`` may be a ``TEXT_POSTPROCESSORS`` registry key or a callable.
        Any remaining keys are forwarded to it as keyword arguments. The
        config passed in is not modified.
        """
        kwargs = dict(proc_cfg)
        proc = kwargs.pop('type')
        if isinstance(proc, str):
            proc = TEXT_POSTPROCESSORS.get(proc)
        if not kwargs:
            return proc
        return lambda text: proc(text, **kwargs)

    @staticmethod
    def _load_predictions(filename):
        """Return the prediction records as a dict keyed ``'0'``, ``'1'``, ...

        If ``filename`` exists, it is loaded as-is. Otherwise the shards
        ``<root>_0<ext>``, ``<root>_1<ext>``, ... are loaded in order,
        stopping at the first missing one. Their entries are concatenated
        and re-keyed consecutively.
        """
        if osp.exists(osp.realpath(filename)):
            return mmengine.load(filename)
        root, ext = osp.splitext(filename)
        preds = {}
        shard_idx = 0
        shard_path = f'{root}_{shard_idx}{ext}'
        while osp.exists(osp.realpath(shard_path)):
            shard = mmengine.load(shard_path)
            for local_idx in range(len(shard)):
                # Keys stay dense, so len(preds) is the next global index.
                preds[str(len(preds))] = shard[str(local_idx)]
            shard_idx += 1
            shard_path = f'{root}_{shard_idx}{ext}'
        return preds

    @staticmethod
    def _collect_cases(preds, pred_strs, references):
        """Build the ``(allcase, badcase)`` lists of JSON-serialisable dicts.

        PPL-style predictions are records that contain
        ``'in-context examples'``. For these, each case is looked up via its
        per-label entries ``'label: <label>'``. The case is bad when the
        predicted label differs from the reference. Cases whose entries are
        missing are skipped entirely.

        For any other prediction format, every case goes into both lists.
        """
        allcase, badcase = [], []
        # An empty prediction dict has no '0' record; previously this
        # raised KeyError. Treat it as the non-PPL format (no cases).
        is_ppl = 'in-context examples' in preds.get('0', {})
        for idx, (pred_str, reference) in enumerate(
                zip(tqdm(pred_strs), references)):
            ref_str = str(reference)
            if is_ppl:
                try:
                    record = preds[str(idx)]
                    pred_entry = record['label: ' + pred_str]
                    ref_entry = record['label: ' + ref_str]
                    item = {
                        'prediction_prompt': pred_entry['testing input'],
                        'prediction': pred_str,
                        'prediction_PPL': pred_entry['PPL'],
                        'reference_prompt': ref_entry['testing input'],
                        'reference': ref_str,
                        'reference_PPL': ref_entry['PPL'],
                    }
                except KeyError:
                    continue
                if pred_str != ref_str:
                    badcase.append(item)
            else:
                item = {
                    'origin_prompt': preds[str(idx)]['origin_prompt'],
                    'prediction': pred_str,
                    'reference': ref_str,
                }
                # NOTE(review): generation outputs are not judged here.
                # Every case is recorded as bad, so case_analysis/bad
                # duplicates case_analysis/all. This is presumably
                # deliberate, since exact string equality is not a valid
                # check for free-form generations. Confirm before changing.
                badcase.append(item)
            allcase.append(item)
        return allcase, badcase

    def _dump(self, subdir, cases):
        """Write ``cases`` as indented UTF-8 JSON under ``<work_dir>/<subdir>``."""
        out_path = get_infer_output_path(
            self.cfg['model'], self.cfg['dataset'],
            osp.join(self.work_dir, subdir))
        mkdir_or_exist(osp.split(out_path)[0])
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(cases, f, indent=4, ensure_ascii=False)

    def run(self):
        """Load predictions, compare them with the references and dump cases.

        Prints a message and returns without writing anything in two cases:
        when neither the prediction file nor its first shard exists, and
        when the number of predictions differs from the number of
        references.
        """
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        root, ext = osp.splitext(filename)
        partial_filename = root + '_0' + ext

        if not osp.exists(osp.realpath(filename)) and not osp.exists(
                osp.realpath(partial_filename)):
            print(f'{filename} not found')
            return

        dataset = build_dataset_from_cfg(self.dataset_cfg)
        if 'dataset_postprocessor' in self.eval_cfg:
            # Resolve the postprocessor once, not once per sample.
            ds_proc = self._build_postprocessor(
                self.eval_cfg['dataset_postprocessor'])
            column = self.ds_column

            def postprocess(sample):
                sample[column] = ds_proc(sample[column])
                return sample

            dataset = dataset.map(postprocess)

        preds = self._load_predictions(filename)
        pred_strs = [preds[str(i)]['prediction'] for i in range(len(preds))]

        if 'pred_postprocessor' in self.eval_cfg:
            pred_proc = self._build_postprocessor(
                self.eval_cfg['pred_postprocessor'])
            pred_strs = [pred_proc(s) for s in pred_strs]

        if self.ds_split:
            references = dataset[self.ds_split][self.ds_column]
        else:
            references = dataset[self.ds_column]

        if len(pred_strs) != len(references):
            print('length mismatch')
            return

        allcase, badcase = self._collect_cases(preds, pred_strs, references)
        self._dump('case_analysis/bad', badcase)
        self._dump('case_analysis/all', allcase)
|
|
|
|
def dispatch_tasks(cfg, force=False):
    """Run a BadcaseShower for every (model, dataset) pair in ``cfg``.

    A pair is skipped when its ``case_analysis/all`` output already exists
    under ``cfg['work_dir']``, unless ``force`` is true. With ``force``, the
    existence check is not performed at all.
    """
    for model in cfg['models']:
        for dataset in cfg['datasets']:
            work_dir = cfg['work_dir']
            if not force:
                finished_marker = get_infer_output_path(
                    model, dataset, osp.join(work_dir, 'case_analysis/all'))
                if osp.exists(finished_marker):
                    continue
            task_cfg = dict(model=model, dataset=dataset, work_dir=work_dir)
            BadcaseShower(task_cfg).run()
|
|
|
|
def main():
    """Entry point: load the config, resolve ``work_dir`` and dispatch tasks."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # An explicit --work-dir wins. Otherwise keep the config's own work_dir,
    # falling back to ./outputs/default.
    if args.work_dir is None:
        cfg.setdefault('work_dir', './outputs/default')
    else:
        cfg['work_dir'] = args.work_dir
    dispatch_tasks(cfg, force=args.force)


if __name__ == '__main__':
    main()
|
|