| import base64 |
| from io import BytesIO |
|
|
|
|
def png_to_base64(image_file):
    """Encode an image file as a base64 ``data:`` URI string.

    Despite the name, any extension listed in ``EXTENSIONS`` below is accepted.
    The extension (matched case-insensitively) selects the Pillow format used to
    re-encode the image before base64-encoding it.

    :param image_file: path (str or os.PathLike) of the image on disk.
    :return: ``str()`` of the bytes ``b"data:image/<fmt>;base64,<payload>"``,
        where ``<fmt>`` is the lower-cased Pillow format name.
        Because ``str()`` of a ``bytes`` object is its repr, the returned text
        literally looks like ``"b'data:image/png;base64,...'"``: the ``b'``
        prefix and trailing quote are part of the result.
        NOTE(review): presumably the LLaVa server expects/strips this wrapper;
        confirm before changing the format.
    :raises ValueError: if the file extension is not in ``EXTENSIONS``.
    """
    from pathlib import Path

    # Lower-case file extension -> Pillow format name used for re-encoding.
    # NOTE(review): per Pillow's file-format docs some of these are read-only
    # (e.g. PSD, PCD, MIC) or write-only (e.g. PDF, PALM), so open()/save()
    # will raise for such files.
    EXTENSIONS = {'.png': 'PNG', '.apng': 'PNG', '.blp': 'BLP', '.bmp': 'BMP', '.dib': 'DIB', '.bufr': 'BUFR',
                  '.cur': 'CUR', '.pcx': 'PCX', '.dcx': 'DCX', '.dds': 'DDS', '.ps': 'EPS', '.eps': 'EPS',
                  '.fit': 'FITS', '.fits': 'FITS', '.fli': 'FLI', '.flc': 'FLI', '.fpx': 'FPX', '.ftc': 'FTEX',
                  '.ftu': 'FTEX', '.gbr': 'GBR', '.gif': 'GIF', '.grib': 'GRIB', '.h5': 'HDF5', '.hdf': 'HDF5',
                  '.jp2': 'JPEG2000', '.j2k': 'JPEG2000', '.jpc': 'JPEG2000', '.jpf': 'JPEG2000', '.jpx': 'JPEG2000',
                  '.j2c': 'JPEG2000', '.icns': 'ICNS', '.ico': 'ICO', '.im': 'IM', '.iim': 'IPTC', '.jfif': 'JPEG',
                  '.jpe': 'JPEG', '.jpg': 'JPEG', '.jpeg': 'JPEG', '.tif': 'TIFF', '.tiff': 'TIFF', '.mic': 'MIC',
                  '.mpg': 'MPEG', '.mpeg': 'MPEG', '.mpo': 'MPO', '.msp': 'MSP', '.palm': 'PALM', '.pcd': 'PCD',
                  '.pdf': 'PDF', '.pxr': 'PIXAR', '.pbm': 'PPM', '.pgm': 'PPM', '.ppm': 'PPM', '.pnm': 'PPM',
                  '.psd': 'PSD', '.qoi': 'QOI', '.bw': 'SGI', '.rgb': 'SGI', '.rgba': 'SGI', '.sgi': 'SGI',
                  '.ras': 'SUN', '.tga': 'TGA', '.icb': 'TGA', '.vda': 'TGA', '.vst': 'TGA', '.webp': 'WEBP',
                  '.wmf': 'WMF', '.emf': 'WMF', '.xbm': 'XBM', '.xpm': 'XPM'}

    ext = Path(image_file).suffix
    # Case-insensitive lookup so "photo.JPG" is handled like "photo.jpg".
    iformat = EXTENSIONS.get(ext.lower())
    if iformat is None:
        raise ValueError("Invalid file extension %s for file %s" % (ext, image_file))

    # Imported only after validation so a bad extension is reported without
    # needing Pillow installed.
    from PIL import Image

    buffered = BytesIO()
    # The context manager closes the underlying file handle once encoded.
    with Image.open(image_file) as image:
        image.save(buffered, format=iformat)
    img_str = base64.b64encode(buffered.getvalue())

    # Deliberately str() of bytes (see docstring): yields "b'data:...'".
    img_str = str(bytes("data:image/%s;base64," % iformat.lower(), encoding='utf-8') + img_str)
    return img_str
|
|
|
|
def _split_llava_address(llava_model, image_model):
    """Split ``[http://|https://]host:port[:model]`` into ``(url, image_model)``.

    ``url`` keeps any scheme prefix and the ``host:port`` part. If a third
    ``:``-separated field is present, everything after the second ``:`` becomes
    the returned ``image_model``. Otherwise the given ``image_model`` is
    returned unchanged.

    :raises ValueError: if no ``host:port`` pair is present.
    """
    prefix = ''
    for scheme in ('http://', 'https://'):
        if llava_model.startswith(scheme):
            prefix = scheme
            break
    address = llava_model[len(prefix):]

    # maxsplit=2 keeps model names that themselves contain ':' intact.
    parts = address.split(':', 2)
    if len(parts) < 2:
        raise ValueError("LLaVa address must look like [http(s)://]host:port[:model], got %r" % llava_model)
    if len(parts) == 3:
        image_model = parts[2]
    return prefix + ':'.join(parts[:2]), image_model


def get_llava_response(file, llava_model,
                       prompt=None,
                       image_model='llava-v1.5-13b', temperature=0.2,
                       top_p=0.7, max_new_tokens=512):
    """Ask a LLaVa Gradio server to respond to ``prompt`` about an image.

    :param file: path of the image; encoded with ``png_to_base64``.
    :param llava_model: server address ``[http://|https://]host:port[:model]``.
        An optional third ``:``-separated field overrides ``image_model``.
    :param prompt: question to ask. ``None`` or ``'auto'`` selects a default
        image-description prompt.
    :param image_model: model to select on the server. If empty or not among
        the server's models, the server's first model is used instead.
    :param temperature: sampling setting forwarded to ``/textbox_api_submit``.
    :param top_p: sampling setting forwarded to ``/textbox_api_submit``.
    :param max_new_tokens: sent as the max output tokens to
        ``/textbox_api_submit``.
    :return: ``(response, prompt)``, where ``prompt`` is the prompt actually sent.
    :raises ValueError: if ``llava_model`` lacks a ``host:port`` part.
    :raises RuntimeError: if the server reports no models.
    """
    if prompt in ['auto', None]:
        prompt = "Describe the image and what does the image say?"

    server_url, image_model = _split_llava_address(llava_model, image_model)

    img_str = png_to_base64(file)

    from gradio_client import Client
    client = Client(server_url, serialize=False)

    load_res = client.predict(api_name='/demo_load')
    # Presumably each choice is a (label, value) pair whose value is the
    # model name -- TODO confirm against the server.
    model_options = [x[1] for x in load_res['choices']]
    if not model_options:
        raise RuntimeError("LLaVa endpoint has no models: %s" % str(load_res))

    # Silently fall back to the server's first model when the request is
    # empty or unavailable.
    if not image_model or image_model not in model_options:
        image_model = model_options[0]

    image_process_mode = "Default"
    include_image = False
    # The return value of this first call is not used; presumably it stages the
    # prompt and image in the server's session -- TODO confirm.
    client.predict(prompt, img_str, image_process_mode, include_image, api_name='/textbox_api_btn')

    res = client.predict(image_model, temperature, top_p, max_new_tokens, include_image,
                         api_name='/textbox_api_submit')
    # Presumably res is the chat history and [-1][-1] the latest bot reply --
    # TODO confirm.
    return res[-1][-1], prompt
|
|