from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os


# Model to serve and the CPU device used by the float32 fallback path below
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = torch.device("cpu")
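
# Runtime dependencies implied by this file (a sketch, not a pinned list):
# fastapi, pydantic and uvicorn for the API, torch and transformers for the model,
# accelerate for device_map="auto", and bitsandbytes for the 4-bit path below.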


# 4-bit NF4 quantization settings (bitsandbytes). Note that bitsandbytes 4-bit
# loading generally needs a CUDA GPU, so on a CPU-only machine this path is
# expected to fail and the float32 fallback below is used instead.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)


# Load the tokenizer and model: try the quantized path first, then fall back to a
# plain float32 CPU load if quantization is unavailable.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Attempting to load with 4-bit quantization...")

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )
    print(f"Model {model_id} loaded and quantized.")

except Exception as e_quant:

    print(f"Quantization failed: {e_quant}. Trying a float32 CPU load (warning: may cause OOM).")

    try:
        # Re-create the tokenizer in case the failure happened before it was assigned
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            trust_remote_code=True
        ).to(device)
        print(f"Model {model_id} loaded on CPU (float32).")
    except Exception as e_cpu:
        print(f"Critical failure while loading on CPU: {e_cpu}")
        raise e_cpu


# Inference only: switch off dropout and other training-time behaviour
model.eval()


app = FastAPI(
    title="NLP Space - TinyLlama 1.1B Chat API (CPU)",
    description="REST API for text generation, summarization and classification, optimized for CPU."
)


# Request schema shared by all endpoints
class PromptRequest(BaseModel):
    prompt: str
    max_tokens: int = 500
    temperature: float = 0.7


def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: int, temperature: float):
    """Build a chat prompt, run generation and return only the assistant's reply."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

    # add_generation_prompt=True appends the assistant turn marker so the model
    # continues as the assistant; that marker is also what the split below relies on.
    text_to_generate = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # With device_map="auto" the reported device can be "meta"; fall back to CPU in that case.
    real_device = model.device if model.device.type != 'meta' else torch.device("cpu")
    inputs = tokenizer(text_to_generate, return_tensors="pt").to(real_device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False  # trades speed for a smaller memory footprint
        )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Keep only the text generated after the assistant marker.
    response_start_tag = "<|assistant|>"
    if response_start_tag in generated_text:
        return generated_text.split(response_start_tag, 1)[1].strip()

    return generated_text.strip()
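
# For reference, TinyLlama-1.1B-Chat ships a Zephyr-style chat template, so the
# prompt built above should look roughly like the sketch below (the layout is an
# assumption taken from the tokenizer's template, not guaranteed verbatim):
#
#   <|system|>
#   {system_prompt}</s>
#   <|user|>
#   {user_prompt}</s>
#   <|assistant|>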


@app.post("/generate")
async def generate(request: PromptRequest):
    """Free-form text generation."""
    system_prompt = "You are a very helpful and creative AI assistant."
    try:
        result = generate_text_from_model(
            system_prompt=system_prompt,
            user_prompt=request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )
        return {"result": result}
    except Exception as e:
        return {"error": str(e)}
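
# Illustrative request (a sketch, assuming the API is reachable on localhost:8000;
# adjust host and port to your deployment):
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Write a haiku about the sea.", "max_tokens": 100, "temperature": 0.7}'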


@app.post("/summarize")
async def summarize(request: PromptRequest):
    """Summarize the text passed in the prompt field."""
    system_prompt = "You are an expert in concise, accurate summarization. Your goal is to summarize the provided text while preserving its main idea."
    user_prompt = f"Summarize the following text concisely and factually:\n\n---\n\n{request.prompt}"
    try:
        result = generate_text_from_model(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            max_tokens=request.max_tokens,
            temperature=0.3  # lower temperature for a more factual summary
        )
        return {"result": result}
    except Exception as e:
        return {"error": str(e)}


@app.post("/classify")
async def classify(request: PromptRequest):
    """Classify the prompt text; the candidate labels must be included in the prompt itself."""
    system_prompt = "You are an expert classifier. Respond only with the requested classification label, with no extra sentences."
    user_prompt = request.prompt
    try:
        result = generate_text_from_model(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            max_tokens=50,
            temperature=0.1  # near-deterministic output for labelling
        )
        return {"result": result}
    except Exception as e:
        return {"error": str(e)}
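
# Illustrative classification request (a sketch; the label set here is an
# assumption, since callers must embed their own candidate labels in the prompt):
#
#   curl -X POST http://localhost:8000/classify \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Classify the sentiment as POSITIVE or NEGATIVE: The service was excellent."}'


# Minimal local entry point, a convenience sketch assuming uvicorn is installed.
# On a hosted Space the platform usually starts the server itself, so this block
# only matters when running the API locally.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)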