| |
| |
|
|
| import copy |
| import os |
| import pickle |
| import shutil |
| import socket |
| import subprocess |
| import sys |
| import tarfile |
| import tempfile |
| import unittest |
| from collections import OrderedDict |
| from collections.abc import Mapping |
| from os.path import expanduser |
|
|
| import numpy as np |
| import requests |
| from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH |
|
|
# Default test level; may be overridden at runtime via the TEST_LEVEL
# environment variable (see test_level()) or set_test_level().
TEST_LEVEL = 2
# Name of the environment variable that overrides TEST_LEVEL.
TEST_LEVEL_STR = 'TEST_LEVEL'

# Hub access tokens for CI runs, read from the environment (None when unset).
TEST_ACCESS_TOKEN1 = os.environ.get('TEST_ACCESS_TOKEN_CITEST', None)
TEST_ACCESS_TOKEN2 = os.environ.get('TEST_ACCESS_TOKEN_SDKDEV', None)

# Model display name (Chinese: "internal test model") and owning org
# used by hub-related test cases.
TEST_MODEL_CHINESE_NAME = '内部测试模型'
TEST_MODEL_ORG = 'citest'
|
|
|
|
def delete_credential():
    """Remove the locally cached ModelScope credentials directory, if present."""
    credentials_dir = expanduser(DEFAULT_CREDENTIALS_PATH)
    shutil.rmtree(credentials_dir, ignore_errors=True)
|
|
|
|
def test_level():
    """Return the active test level, honoring a TEST_LEVEL env-var override.

    When the environment variable named by ``TEST_LEVEL_STR`` is set, the
    module-wide ``TEST_LEVEL`` is updated from it before being returned.
    """
    global TEST_LEVEL
    override = os.environ.get(TEST_LEVEL_STR)
    if override is not None:
        TEST_LEVEL = int(override)
    return TEST_LEVEL
|
|
|
|
def require_tf(test_case):
    """Mark ``test_case`` as skipped — TensorFlow-dependent tests are disabled."""
    return unittest.skip('test requires TensorFlow')(test_case)
|
|
|
|
def require_torch(test_case):
    """No-op marker: torch is assumed available, so the case is returned as-is."""
    return test_case
|
|
|
|
def set_test_level(level: int):
    """Override the module-wide TEST_LEVEL used to gate test execution."""
    global TEST_LEVEL
    TEST_LEVEL = level
|
|
|
|
class DummyTorchDataset:
    """Minimal map-style dataset that yields the same (feat, labels) pair ``num`` times."""

    def __init__(self, feat, label, num) -> None:
        self.feat = feat
        self.label = label
        self.num = num

    def __getitem__(self, index):
        # torch is imported lazily so the module can load without it installed.
        import torch
        return {
            'feat': torch.Tensor(self.feat),
            'labels': torch.Tensor(self.label),
        }

    def __len__(self):
        return self.num
|
|
|
|
def create_dummy_test_dataset(feat, label, num):
    """Factory wrapper returning a :class:`DummyTorchDataset` for the given args."""
    return DummyTorchDataset(feat=feat, label=label, num=num)
|
|
|
|
def download_and_untar(fpath, furl, dst) -> str:
    """Download ``furl`` to ``fpath`` (if not already present) and untar it into ``dst``.

    Args:
        fpath: Local path of the tar archive (e.g. ``/tmp/data.tar.gz``).
        furl: URL to fetch the archive from when ``fpath`` does not exist.
        dst: Directory the archive is extracted into.

    Returns:
        str: ``<dirname(fpath)>/<archive name stripped of up to two extensions>``.
            NOTE(review): extraction happens under ``dst``, so the returned path
            only points at the extracted tree when ``dst == dirname(fpath)``.
    """
    if not os.path.exists(fpath):
        r = requests.get(furl, timeout=300)
        # Fail loudly on HTTP errors instead of silently writing an error
        # page to disk (the original wrote r.content unconditionally).
        r.raise_for_status()
        with open(fpath, 'wb') as f:
            f.write(r.content)

    file_name = os.path.basename(fpath)
    root_dir = os.path.dirname(fpath)
    # Strip up to two extensions, e.g. 'data.tar.gz' -> 'data'.
    target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0]
    target_dir_path = os.path.join(root_dir, target_dir_name)

    # Context manager guarantees the tar handle is closed (it leaked before).
    # NOTE(review): extractall is path-traversal prone on untrusted archives;
    # only trusted URLs should be passed here.
    with tarfile.open(fpath) as t:
        t.extractall(path=dst)

    return target_dir_path
|
|
|
|
def get_case_model_info():
    """Grep tests/ for 'damo/...' model ids and map each model to its test cases.

    Returns:
        OrderedDict: model name -> set of dotted test-module names
        (``tests/foo/bar.py`` becomes ``foo.bar``) that reference it.
    """
    _, result = subprocess.getstatusoutput(
        'grep -rn "damo/" tests/ | grep -v ".pyc" | grep -v "Binary file" | grep -v run.py ')
    test_cases = OrderedDict()
    model_cases = OrderedDict()
    for line in result.split('\n'):
        line = line.strip()
        model_pos = line.find('damo')
        # Skip empty output and grep diagnostics: the original indexed
        # line[model_pos - 1] unguarded and crashed/garbled on such lines.
        if model_pos <= 0:
            continue
        test_file = line.split(':')[0]
        # The model id is quoted; the matching closing quote ends the name.
        left_quote = line[model_pos - 1]
        rquote_idx = line.rfind(left_quote)
        model_name = line[model_pos:rquote_idx]

        test_cases.setdefault(test_file, set()).add(model_name)
        model_cases.setdefault(model_name, set()).add(
            test_file.replace('tests/', '').replace('.py', '').replace('/', '.'))

    return model_cases
|
|
|
|
def compare_arguments_nested(print_content, arg1, arg2, rtol=1.e-3, atol=1.e-8, ignore_unknown_type=True):
    """Recursively compare two (possibly nested) values for approximate equality.

    Exact comparison for ints/strs/bools, ``np.isclose`` for floats, element-wise
    recursion for sequences and mappings, and NaN-tolerant closeness for ndarrays.

    Args:
        print_content: Message printed on mismatch; pass None to stay silent
            (nested recursive calls always pass None).
        arg1: First value; must have the same type name as ``arg2``.
        arg2: Second value.
        rtol: Relative tolerance for float comparison.
        atol: Absolute tolerance for float comparison.
        ignore_unknown_type: If True, values of unhandled types compare equal;
            if False, a ValueError is raised for them.

    Returns:
        bool: True if the two values are considered equal.

    Raises:
        ValueError: For unhandled types when ``ignore_unknown_type`` is False.
    """
    type1 = type(arg1)
    type2 = type(arg2)
    if type1.__name__ != type2.__name__:
        if print_content is not None:
            print(f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}')
        return False

    if arg1 is None:
        return True
    elif isinstance(arg1, (int, str, bool, np.bool_, np.integer, np.str_)):
        if arg1 != arg2:
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (float, np.floating)):
        if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (tuple, list)):
        if len(arg1) != len(arg2):
            if print_content is not None:
                print(f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}')
            return False
        if not all(
                compare_arguments_nested(None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
                for sub_arg1, sub_arg2 in zip(arg1, arg2)):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, Mapping):
        keys1 = arg1.keys()
        keys2 = arg2.keys()
        if len(keys1) != len(keys2):
            if print_content is not None:
                print(f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}')
            return False
        if len(set(keys1) - set(keys2)) > 0:
            if print_content is not None:
                print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
            return False
        if not all(compare_arguments_nested(None, arg1[key], arg2[key], rtol=rtol, atol=atol) for key in keys1):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, np.ndarray):
        # Map None entries (object arrays) to NaN so they compare equal via
        # equal_nan. np.nan replaces np.NaN, which was removed in NumPy 2.0.
        arg1 = np.where(np.equal(arg1, None), np.nan, arg1).astype(dtype=float)
        arg2 = np.where(np.equal(arg2, None), np.nan, arg2).astype(dtype=float)
        if not all(np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True).flatten()):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    else:
        if ignore_unknown_type:
            return True
        else:
            raise ValueError(f'type not supported: {type1}')
|
|
|
|
# Source template for the temporary launcher script written by
# DistributedTestCase._start. It is rendered via .format() with four values:
# (module_name, module_name, func_qualname, rendered_param_string), imported
# on each rank under torch.distributed.launch; every rank (or only rank 0)
# pickles the function's return value to --save_file for the parent process.
_DIST_SCRIPT_TEMPLATE = """
import ast
import argparse
import pickle
import torch
from torch import distributed as dist
from modelscope.utils.torch_utils import get_dist_info
import {}

parser = argparse.ArgumentParser()
parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save all ranks results')
parser.add_argument('--save_file', type=str, help='save file')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()


def main():
    results = {}.{}({}) # module.func(params)
    if args.save_all_ranks:
        save_file = args.save_file + str(dist.get_rank())
        with open(save_file, 'wb') as f:
            pickle.dump(results, f)
    else:
        rank, _ = get_dist_info()
        if rank == 0:
            with open(args.save_file, 'wb') as f:
                pickle.dump(results, f)


if __name__ == '__main__':
    main()
"""
|
|
|
|
class DistributedTestCase(unittest.TestCase):
    """Distributed TestCase for test function with distributed mode.
    Examples:
        >>> import torch
        >>> from torch import distributed as dist
        >>> from modelscope.utils.torch_utils import init_dist

        >>> def _test_func(*args, **kwargs):
        >>>     init_dist(launcher='pytorch')
        >>>     rank = dist.get_rank()
        >>>     if rank == 0:
        >>>         value = torch.tensor(1.0).cuda()
        >>>     else:
        >>>         value = torch.tensor(2.0).cuda()
        >>>     dist.all_reduce(value)
        >>>     return value.cpu().numpy()

        >>> class DistTest(DistributedTestCase):
        >>>     def test_function_dist(self):
        >>>         args = ()  # args should be python builtin type
        >>>         kwargs = {}  # kwargs should be python builtin type
        >>>         self.start(
        >>>             _test_func,
        >>>             num_gpus=2,
        >>>             assert_callback=lambda x: self.assertEqual(x, 3.0),
        >>>             *args,
        >>>             **kwargs,
        >>>         )
    """

    def _start(self, dist_start_cmd, func, num_gpus, assert_callback=None, save_all_ranks=False, *args, **kwargs):
        # Generate a standalone launcher script that imports the module
        # containing ``func`` and calls it, run it under ``dist_start_cmd``,
        # then load the pickled result(s) back into this process.
        script_path = func.__code__.co_filename
        script_dir, script_name = os.path.split(script_path)
        script_name = os.path.splitext(script_name)[0]
        func_name = func.__qualname__

        # Render args/kwargs as Python source text for the generated script;
        # only builtin literal types survive this repr-style round-trip.
        func_params = []
        for arg in args:
            if isinstance(arg, str):
                arg = ('\'{}\''.format(arg))
            func_params.append(str(arg))

        for k, v in kwargs.items():
            if isinstance(v, str):
                v = ('\'{}\''.format(v))
            func_params.append('{}={}'.format(k, v))

        func_params = ','.join(func_params).strip(',')

        # NOTE(review): only the *name* of each NamedTemporaryFile is kept;
        # the file object is dropped immediately, so this is a mktemp-style
        # race in principle (acceptable for test scaffolding).
        tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name
        tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name

        with open(tmp_run_file, 'w') as f:
            print('save temporary run file to : {}'.format(tmp_run_file))
            print('save results to : {}'.format(tmp_res_file))
            run_file_content = _DIST_SCRIPT_TEMPLATE.format(script_name, script_name, func_name, func_params)
            f.write(run_file_content)

        # When saving all ranks, each rank writes <tmp_res_file><rank>;
        # otherwise only rank 0 writes <tmp_res_file> (see the template).
        tmp_res_files = []
        if save_all_ranks:
            for i in range(num_gpus):
                tmp_res_files.append(tmp_res_file + str(i))
        else:
            tmp_res_files = [tmp_res_file]
        self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files)

        # Child env: make ``func``'s module importable and disable NCCL P2P
        # (avoids hangs on some multi-GPU test hosts).
        tmp_env = copy.deepcopy(os.environ)
        tmp_env['PYTHONPATH'] = ':'.join((tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':')
        tmp_env['NCCL_P2P_DISABLE'] = '1'
        script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks, tmp_res_file)
        script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params)
        print('script command: %s' % script_cmd)
        res = subprocess.call(script_cmd, shell=True, env=tmp_env)

        # Collect the pickled return value(s); a single object when only
        # rank 0 saved, otherwise a list ordered by rank.
        script_res = []
        for res_file in tmp_res_files:
            with open(res_file, 'rb') as f:
                script_res.append(pickle.load(f))
        if not save_all_ranks:
            script_res = script_res[0]

        if assert_callback:
            assert_callback(script_res)

        self.assertEqual(res, 0, msg='The test function ``{}`` in ``{}`` run failed!'.format(func_name, script_name))

        return script_res

    def start(self, func, num_gpus, assert_callback=None, save_all_ranks=False, *args, **kwargs):
        # Build a default torch.distributed.launch command bound to this
        # host and a free port, unless the caller supplies 'dist_start_cmd'.
        from .torch_utils import _find_free_port
        ip = socket.gethostbyname(socket.gethostname())
        if 'dist_start_cmd' in kwargs:
            dist_start_cmd = kwargs.pop('dist_start_cmd')
        else:
            dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d ' \
                             '--master_addr=\'%s\' --master_port=%s' % (sys.executable, num_gpus, ip, _find_free_port())

        return self._start(
            dist_start_cmd=dist_start_cmd,
            func=func,
            num_gpus=num_gpus,
            assert_callback=assert_callback,
            save_all_ranks=save_all_ranks,
            *args,
            **kwargs)

    def clean_tmp(self, tmp_file_list):
        """Best-effort removal of the temp run/result files (and dirs) we created."""
        for file in tmp_file_list:
            if os.path.exists(file):
                if os.path.isdir(file):
                    shutil.rmtree(file)
                else:
                    os.remove(file)
|
|