| import numpy as np |
|
|
|
|
def init(cfg):
    """Register the status-bar button and prepare its constrained samplers.

    Work done on ``cfg``:

    * every entry of ``cfg['btn_status_bar_list']`` gets its ``'key'`` and
      ``'desc'`` formatted with the current character/user names and
      pre-tokenized into ``'key_t'`` / ``'desc_t'`` (``'desc_t'`` only when
      the formatted description is non-empty);
    * every field in an entry's ``'combine'`` list gets ``'prefix_t'`` (when it
      has a prefix) and a sampler ``'fn'`` chosen by its ``'type'``
      (``'int'``, ``'set'`` or ``'str'``);
    * ``cfg['btn_status_bar_fn']`` is built and chained onto
      ``cfg['btn_status_bar'].click`` between ``cfg['btn_start']`` and
      ``cfg['btn_finish']``.

    Each ``'fn'`` is called as ``fn(eval_t, sample_t)`` and returns the text it
    generated: ``sample_t(logits_processor)`` returns one token id and
    ``eval_t(tokens)`` feeds tokens to the model.

    Raises:
        ValueError: if an ``'int'`` field is configured but the tokenizer does
            not split ``'0123456789'`` into exactly ten tokens.
    """
    chat_template = cfg['chat_template']
    model = cfg['model']
    s_info = cfg['s_info']
    lock = cfg['session_lock']
    text_format = cfg['text_format']

    def str_tokenize(s):
        """Tokenize ``s`` as if it followed a line break, without that break.

        ``chat_template.nl`` is prepended before tokenizing (presumably so that
        ``s`` tokenizes the way it would mid-conversation) and a leading token
        found in ``chat_template.onenl`` is dropped again. Otherwise the tokens
        are returned as produced.
        """
        tokens = model.tokenize((chat_template.nl + s).encode('utf-8'),
                                add_bos=False, special=False)
        if tokens[0] in chat_template.onenl:
            return tokens[1:]
        return tokens

    def btn_status_bar_fn_mask():
        """Return an additive float32 logits mask that bans every token.

        Its length is the last dimension of ``model.scores``; callers set the
        entries of allowed tokens to 0.
        """
        return np.full((model.scores.shape[-1],), -np.inf, dtype=np.single)

    def btn_status_bar_fn_int(unit: str):
        """Build a sampler for a run of digits optionally followed by ``unit``.

        Only digit tokens, EOS tokens and the first token of ``unit`` can be
        sampled. Sampling stops at EOS or at that first unit token; the whole
        ``unit`` is then evaluated and appended, so a non-empty ``unit`` always
        ends the result (e.g. ``'42%'``).
        """
        t_int = str_tokenize('0123456789')
        if len(t_int) != 10:
            # The digit mask below only works if each digit is its own token.
            raise ValueError('"int" status fields need one token per digit, '
                             'but "0123456789" tokenized to %d tokens' % len(t_int))
        fn_int_mask = btn_status_bar_fn_mask()
        fn_int_mask[chat_template.eos] = 0
        fn_int_mask[t_int] = 0
        unit_t = str_tokenize(unit) if unit else []
        # Sampling the first token of the unit ends the number.
        unit_stop = unit_t[0] if unit_t else None
        if unit_stop is not None:
            fn_int_mask[unit_stop] = 0

        def logits_processor(_input_ids, logits):
            return logits + fn_int_mask

        def inner(eval_t, sample_t):
            retn = []
            while True:
                token = sample_t(logits_processor)
                if token in chat_template.eos or token == unit_stop:
                    # The terminating token itself is not evaluated here.
                    break
                retn.append(token)
                eval_t([token])
            if unit_t:
                # Evaluate the complete unit, whichever token ended the loop.
                eval_t(unit_t)
                retn.extend(unit_t)
            return model.str_detokenize(retn)

        return inner

    def btn_status_bar_fn_set(value):
        """Build a sampler that picks exactly one string from ``value``.

        At each step only tokens that continue at least one remaining option
        may be sampled (plus EOS when an option is complete at that point), so
        options sharing leading tokens all stay selectable. As soon as the
        choice is unique, its remaining tokens are evaluated and the option
        string is returned. An empty ``value`` yields ``''``.
        """
        options = [(str_tokenize(_v), _v) for _v in value]

        def inner(eval_t, sample_t):
            candidates = options
            pos = 0  # tokens of the chosen prefix already evaluated
            while candidates:
                branches = {}  # next token -> options continuing with it
                finished = None  # text of an option that is complete at ``pos``
                for tokens, text in candidates:
                    if len(tokens) > pos:
                        branches.setdefault(tokens[pos], []).append((tokens, text))
                    elif finished is None:
                        finished = text
                if not branches:
                    # Only complete (identical) options remain.
                    return finished
                mask = btn_status_bar_fn_mask()
                mask[list(branches)] = 0
                if finished is not None:
                    mask[chat_template.eos] = 0

                def logits_processor(_input_ids, logits, _mask=mask):
                    return logits + _mask

                token = sample_t(logits_processor)
                if finished is not None and token not in branches:
                    # An EOS token selects the option that ends here.
                    return finished
                candidates = branches[token]
                if len(candidates) == 1:
                    eval_t(candidates[0][0][pos:])
                    return candidates[0][1]
                eval_t([token])
                pos += 1
            return ''

        return inner

    def btn_status_bar_fn_str():
        """Build an unconstrained sampler that generates one line of text.

        Sampling stops at an EOS token or once the decoded text ends with
        ``'\\n'`` or ``'\\r'``; that last token is not evaluated. The decoded
        text is returned stripped.
        """
        def inner(eval_t, sample_t):
            tokens = []
            text = ''
            while True:
                token = sample_t(None)
                if token in chat_template.eos:
                    break
                tokens.append(token)
                # Decode the whole sequence so the check sees the returned text.
                text = model.str_detokenize(tokens)
                if text.endswith(('\n', '\r')):
                    break
                eval_t([token])
            return text.strip()

        return inner

    role_char = cfg['role_char'].value
    role_usr = cfg['role_usr'].value
    fn_builders = {
        'int': lambda field: btn_status_bar_fn_int(field['unit']),
        'set': lambda field: btn_status_bar_fn_set(field['value']),
        'str': lambda field: btn_status_bar_fn_str(),
    }
    for idx, entry in enumerate(cfg['btn_status_bar_list']):
        entry['key'] = text_format(entry['key'], char=role_char, user=role_usr)
        entry['key_t'] = str_tokenize(entry['key'])
        if idx > 0:
            # Entries after the first start with the last token of
            # ``im_end_nl`` (btn_status_bar closes each entry with im_end_nl).
            entry['key_t'] = chat_template.im_end_nl[-1:] + entry['key_t']
        entry['desc'] = text_format(entry['desc'], char=role_char, user=role_usr)
        if entry['desc']:
            entry['desc_t'] = str_tokenize(entry['desc'])
        for field in entry['combine']:
            if field['prefix']:
                field['prefix_t'] = str_tokenize(field['prefix'])
            build = fn_builders.get(field['type'])
            # Fields of any other type get no 'fn'; btn_status_bar would then
            # fail with KeyError when it reaches them.
            if build is not None:
                field['fn'] = build(field)

    def btn_status_bar(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _char,
                       _rag, _max_tokens):
        """Generate the status table, yielding ``(rows, model.venv_info)``.

        The positional parameters presumably mirror the ``cfg['setting']``
        inputs registered below; ``_usr``, ``_char``, ``_rag`` and
        ``_max_tokens`` are accepted but not used here. ``rows`` is a list of
        ``[key, value]`` pairs and the same list object is yielded after every
        step while it is filled in place. The model work happens between
        ``model.venv_create('status')`` and
        ``model.venv_remove('status', keep_last=1)`` with the session lock held.
        If ``cfg['btn_stop_status']`` is set, a single empty table is yielded.

        Raises:
            RuntimeError: if ``cfg['session_active']`` is falsy.
        """
        with lock:
            if not cfg['session_active']:
                raise RuntimeError
            if cfg['btn_stop_status']:
                yield [], model.venv_info
                return

            def eval_t(tokens):
                return model.eval_t(
                    tokens=tokens,
                    n_keep=_n_keep,
                    n_discard=_n_discard,
                    im_start=chat_template.im_start_token,
                )

            def sample_t(logits_processor):
                return model.sample_t(
                    top_k=_top_k,
                    top_p=_top_p,
                    min_p=_min_p,
                    typical_p=_typical_p,
                    temp=_temperature,
                    repeat_penalty=_repeat_penalty,
                    repeat_last_n=_repeat_last_n,
                    frequency_penalty=_frequency_penalty,
                    presence_penalty=_presence_penalty,
                    tfs_z=_tfs_z,
                    mirostat_mode=_mirostat_mode,
                    mirostat_tau=_mirostat_tau,
                    mirostat_eta=_mirostat_eta,
                    logits_processor=logits_processor,
                )

            model.venv_create('status')
            # '状态' is the role name passed to the chat template ("status").
            eval_t(chat_template('状态'))

            rows = []
            for entry in cfg['btn_status_bar_list']:
                rows.append([entry['key'], ''])
                eval_t(entry['key_t'])
                if entry['desc']:
                    eval_t(entry['desc_t'])
                yield rows, model.venv_info

                for field in entry['combine']:
                    if field['prefix']:
                        if rows[-1][-1]:
                            rows[-1][-1] += field['prefix']
                        else:
                            # Nothing precedes it in the cell: drop leading ':'.
                            rows[-1][-1] += field['prefix'].lstrip(':')
                        eval_t(field['prefix_t'])
                    rows[-1][-1] += field['fn'](eval_t, sample_t)
                    yield rows, model.venv_info
                eval_t(chat_template.im_end_nl)

            model.venv_remove('status', keep_last=1)
            yield rows, model.venv_info

    cfg['btn_status_bar_fn'] = {
        'fn': btn_status_bar,
        'inputs': cfg['setting'],
        'outputs': [cfg['status_bar'], s_info],
    }
    cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])

    cfg['btn_status_bar'].click(
        **cfg['btn_start']
    ).success(
        **cfg['btn_status_bar_fn']
    ).success(
        **cfg['btn_finish']
    )
|
|