import { useState } from 'react';
import { useTranslation } from 'next-i18next';
import ApiKeyInput from './ApiKeyInput';
import EvolutionaryParams from './EvolutionaryParams';
/** Matches an optional scheme plus the Hugging Face host at the start of a URL. */
const HF_URL_PREFIX = /^(?:https?:\/\/)?(?:www\.)?(?:huggingface\.co|hf\.co)\//i;

/**
 * Extracts the identifier the backend expects from whatever the user typed.
 *
 * - `hf`: accepts `namespace/repo` or a Hugging Face URL (including deep links
 *   such as `/tree/main`, `/blob/...`, `/resolve/...`, a trailing slash, or a
 *   query string) and returns `namespace/repo`.
 * - `civitai`: accepts a Civitai URL carrying a model version id and returns
 *   the numeric id.
 *
 * Anything that cannot be recognised is returned trimmed but otherwise
 * unchanged, so the backend remains the final authority on validity.
 */
function normalizeRepoId(input: string, source: string): string {
  const value = input.trim();

  if (source === 'hf') {
    const hasHfHost = HF_URL_PREFIX.test(value);
    // A URL pointing at some other host is not something we can interpret.
    if (!hasHfHost && value.includes('://')) return value;

    const pathOnly = value.replace(HF_URL_PREFIX, '').split(/[?#]/)[0] ?? '';
    const [namespace, name] = pathOnly.split('/').filter(Boolean);
    // Only the first two path segments identify the repo; the rest is a deep link.
    if (namespace && name) return `${namespace}/${name}`;
    // Repo without a namespace given as a URL (e.g. huggingface.co/gpt2).
    if (hasHfHost && namespace) return namespace;
    return value;
  }

  if (source === 'civitai') {
    const match =
      // Page URL: /models/<modelId>[/<slug>]?modelVersionId=<id>
      value.match(/[?&]modelVersionId=(\d+)/) ??
      // API URL: /api/v1/model-versions/<id>
      value.match(/model-versions\/(\d+)/) ??
      // Download URL: /api/download/models/<id>
      value.match(/api\/download\/models\/(\d+)/);
    if (match?.[1]) return match[1];
  }

  return value;
}

/**
 * Builds a human-readable message from a failed response.
 *
 * Looks up `keys` in order on the JSON body. Falls back to `fallback` when the
 * body is not JSON (e.g. an HTML error page from a proxy) or has none of the
 * keys. Non-string values are serialised so they never render as
 * "[object Object]".
 */
async function extractErrorMessage(
  res: Response,
  keys: readonly string[],
  fallback: string,
): Promise<string> {
  let body: unknown;
  try {
    body = await res.json();
  } catch {
    return fallback;
  }
  if (typeof body !== 'object' || body === null) return fallback;

  const record = body as Record<string, unknown>;
  for (const key of keys) {
    const value = record[key];
    if (typeof value === 'string' && value) return value;
    if (value) return JSON.stringify(value);
  }
  return fallback;
}

/** Converts an unknown thrown value into a displayable string. */
function toErrorMessage(err: unknown): string {
  return err instanceof Error ? err.message : String(err);
}

/** Splits a comma- or newline-separated list into trimmed, non-empty names. */
function parseLayerList(raw: string): string[] {
  return raw
    .split(/[,\n]/)
    .map(s => s.trim())
    .filter(Boolean);
}

/**
 * Form for submitting a model merge job.
 *
 * Reads the API credentials from `sessionStorage` (`hf_token_manual`,
 * `civitai_key`), optionally uploads a dataset file to
 * `/api/backend/upload-dataset`, then posts the job description to
 * `/api/backend/submit-job`. Failures are shown inline above the fields.
 */
export default function JobForm() {
  const { t } = useTranslation('common');
  const [method, setMethod] = useState('linear');
  const [mergeType, setMergeType] = useState('linear');
  const [modelASource, setModelASource] = useState('hf');
  const [modelBSource, setModelBSource] = useState('hf');
  const [modelAId, setModelAId] = useState('');
  const [modelBId, setModelBId] = useState('');
  const [alpha, setAlpha] = useState(0.5);
  const [outputRepo, setOutputRepo] = useState('');
  const [datasetFile, setDatasetFile] = useState<File | null>(null);
  // File inputs are uncontrolled; bumping this key remounts the input so the
  // browser's displayed file name is cleared together with `datasetFile`.
  const [fileInputKey, setFileInputKey] = useState(0);
  const [evoParams, setEvoParams] = useState({});
  const [frankenLayers, setFrankenLayers] = useState('');
  const [submitting, setSubmitting] = useState(false);
  const [errorMessage, setErrorMessage] = useState<string | null>(null);

  const isLinear = method === 'linear';
  // The merge-type selector is only shown for the linear method, so a stale
  // 'franken' selection must not leak into other methods.
  const isFranken = isLinear && mergeType === 'franken';

  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault();
    setErrorMessage(null);

    const hfToken = sessionStorage.getItem('hf_token_manual') ?? '';
    const civitaiKey = sessionStorage.getItem('civitai_key') ?? '';
    if (!hfToken) {
      setErrorMessage(t('hf_token_required'));
      return;
    }

    setSubmitting(true);
    try {
      let datasetPath = '';
      if (datasetFile) {
        const formData = new FormData();
        formData.append('file', datasetFile);
        const uploadRes = await fetch('/api/backend/upload-dataset', { method: 'POST', body: formData });
        if (!uploadRes.ok) {
          throw new Error(await extractErrorMessage(uploadRes, ['detail', 'error'], 'Dataset upload failed'));
        }
        const uploaded: unknown = await uploadRes.json();
        const uploadedPath =
          typeof uploaded === 'object' && uploaded !== null
            ? (uploaded as { path?: unknown }).path
            : undefined;
        // Fail loudly rather than silently submitting the job without its dataset.
        if (typeof uploadedPath !== 'string' || !uploadedPath) {
          throw new Error('Dataset upload failed');
        }
        datasetPath = uploadedPath;
      }

      const payload = {
        model_a_source: modelASource,
        model_a_id: normalizeRepoId(modelAId, modelASource),
        model_b_source: modelBSource,
        model_b_id: normalizeRepoId(modelBId, modelBSource),
        method,
        merge_type: isLinear ? mergeType : undefined,
        linear_alpha: alpha,
        output_repo_name: outputRepo.trim(),
        dataset: datasetPath,
        evo_params: method === 'evolutionary' ? evoParams : null,
        hf_token_manual: hfToken,
        civitai_key: civitaiKey,
        franken_layers: isFranken ? parseLayerList(frankenLayers) : undefined,
      };

      const res = await fetch('/api/backend/submit-job', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(payload),
      });
      if (!res.ok) {
        throw new Error(
          await extractErrorMessage(res, ['error', 'detail'], `Request failed with status ${res.status}`),
        );
      }

      alert(t('job_submitted'));
      setModelAId('');
      setModelBId('');
      setOutputRepo('');
      setDatasetFile(null);
      setFileInputKey(k => k + 1);
      setFrankenLayers('');
    } catch (err: unknown) {
      setErrorMessage(toErrorMessage(err));
    } finally {
      setSubmitting(false);
    }
  };

  return (
    <form onSubmit={handleSubmit} className="space-y-4 bg-white p-6 shadow rounded-lg">
      <ApiKeyInput />
      {errorMessage && (
        <div className="bg-red-50 border border-red-200 text-red-700 px-4 py-3 rounded">
          {errorMessage}
        </div>
      )}
      <div className="grid grid-cols-1 md:grid-cols-2 gap-4">
        <div>
          <label className="block text-sm font-medium">{t('model_a')}</label>
          <select value={modelASource} onChange={e => setModelASource(e.target.value)} className="mt-1 block w-full border rounded p-2">
            <option value="hf">HuggingFace</option>
            <option value="civitai">Civitai</option>
          </select>
          <input type="text" placeholder={t('repo_id_or_url')} value={modelAId} onChange={e => setModelAId(e.target.value)} className="mt-1 block w-full border rounded p-2" required />
        </div>
        <div>
          <label className="block text-sm font-medium">{t('model_b')}</label>
          <select value={modelBSource} onChange={e => setModelBSource(e.target.value)} className="mt-1 block w-full border rounded p-2">
            <option value="hf">HuggingFace</option>
            <option value="civitai">Civitai</option>
          </select>
          <input type="text" placeholder={t('repo_id_or_url')} value={modelBId} onChange={e => setModelBId(e.target.value)} className="mt-1 block w-full border rounded p-2" required />
        </div>
      </div>
      <div>
        <label className="block text-sm font-medium">{t('method')}</label>
        <select value={method} onChange={e => setMethod(e.target.value)} className="mt-1 block w-full border rounded p-2">
          <option value="linear">{t('linear_merge_label')}</option>
          <option value="evolutionary">{t('evolutionary')}</option>
        </select>
      </div>
      {isLinear && (
        <>
          <div>
            <label className="block text-sm font-medium">{t('merge_type')}</label>
            <select value={mergeType} onChange={e => setMergeType(e.target.value)} className="mt-1 block w-full border rounded p-2">
              <option value="linear">{t('linear')}</option>
              <option value="slerp">{t('slerp')}</option>
              <option value="franken">{t('franken')}</option>
            </select>
          </div>
          {!isFranken && (
            <div>
              <label className="block text-sm font-medium">{t('alpha')}: {alpha}</label>
              <input type="range" min="0" max="1" step="0.01" value={alpha} onChange={e => setAlpha(parseFloat(e.target.value))} className="w-full" />
            </div>
          )}
          {isFranken && (
            <div>
              <label className="block text-sm font-medium">{t('franken_layers_from_a')}</label>
              <textarea
                value={frankenLayers}
                onChange={e => setFrankenLayers(e.target.value)}
                className="mt-1 block w-full border rounded p-2 text-sm"
                placeholder="layer.0.mlp.fc1, layer.0.mlp.fc2"
                rows={3}
              />
              <p className="text-xs text-gray-500">{t('franken_layers_help')}</p>
            </div>
          )}
        </>
      )}
      {method === 'evolutionary' && (
        <EvolutionaryParams onChange={setEvoParams} />
      )}
      <div>
        <label className="block text-sm font-medium">{t('calibration_dataset')} ({t('optional')})</label>
        <input key={fileInputKey} type="file" onChange={e => setDatasetFile(e.target.files?.[0] ?? null)} />
      </div>
      <div>
        <label className="block text-sm font-medium">{t('output_repo_name')}</label>
        <input type="text" value={outputRepo} onChange={e => setOutputRepo(e.target.value)} placeholder="my-merged-model" className="mt-1 block w-full border rounded p-2" />
      </div>
      <button type="submit" disabled={submitting} className="bg-blue-600 text-white px-4 py-2 rounded hover:bg-blue-700 disabled:opacity-50">
        {submitting ? t('submitting') : t('start_merge')}
      </button>
    </form>
  );
}