"""Gradio app that generates synthetic table data with a local Qwen chat model.

Input is a JSON object of the form {"llm_commands": [<column description>, ...]};
the app prompts the model in batches, parses list-literal output, falls back to
deterministic generators when the model under-delivers, and saves the result to
./train/data.json (UTF-8, Arabic-safe).
"""

import ast
import json
import os
import random
import re
import tempfile

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer


class TableDataGenerator:
    """Wraps a causal-LM chat model and turns column descriptions into rows."""

    def __init__(self, model_name="Qwen/Qwen2.5-3B-Instruct"):
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def generate_batch_data(self, llm_commands, num_rows=1000, batch_size=30):
        """Generate table data in batches for better performance.

        Args:
            llm_commands: list of column descriptions (one per column).
            num_rows: total number of rows requested.
            batch_size: rows requested from the model per batch.

        Returns:
            A list of rows (each a list of strings), at most ``num_rows`` long.
        """
        all_rows = []

        # Describe the columns once; reused in every batch prompt.
        columns_desc = ", ".join(
            f"Column {i + 1}: {cmd}" for i, cmd in enumerate(llm_commands)
        )

        # Ceiling division: enough batches to cover num_rows.
        num_batches = (num_rows + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            current_batch_size = min(batch_size, num_rows - len(all_rows))

            # Retry a few times per batch — the model often returns fewer
            # (or duplicate) rows than asked for.
            batch_rows = []
            max_attempts = 5

            for attempt in range(max_attempts):
                remaining_needed = current_batch_size - len(batch_rows)
                if remaining_needed <= 0:
                    break

                # The "Seed" line varies per attempt to nudge the model
                # toward different completions.
                prompt = f"""Generate exactly {remaining_needed} rows of data for a table with columns:
{columns_desc}

Format: [['value1', 'value2'], ['value3', 'value4']]

Requirements:
- Each row must be different and realistic
- Return ONLY the list, no explanations
- Make data diverse and creative
- Seed: {batch_idx * 10 + attempt}

Generate exactly {remaining_needed} rows:"""

                messages = [
                    {
                        "role": "system",
                        "content": (
                            "You are a precise data generator. Return only valid "
                            "Python list format with exactly the requested number of rows."
                        ),
                    },
                    {"role": "user", "content": prompt},
                ]

                response = self._generate_response(messages, batch_idx * 10 + attempt)
                new_rows = self._parse_response(response, len(llm_commands))

                # Keep only rows not already seen in this batch or overall.
                for row in new_rows:
                    if row not in batch_rows and row not in all_rows:
                        batch_rows.append(row)
                        if len(batch_rows) >= current_batch_size:
                            break

            all_rows.extend(batch_rows)

            # Top up with deterministic fallback data if the model fell short.
            if len(all_rows) < num_rows and len(batch_rows) < current_batch_size:
                fallback_rows = self._generate_fallback_data(
                    llm_commands,
                    current_batch_size - len(batch_rows),
                    len(all_rows),
                )
                all_rows.extend(fallback_rows)

            if len(all_rows) >= num_rows:
                break

        return all_rows[:num_rows]

    def _generate_response(self, messages, seed=None):
        """Run one chat completion and return the decoded continuation only."""
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        # Seed torch so repeated attempts sample different outputs on purpose.
        if seed is not None:
            torch.manual_seed(seed)
        else:
            torch.manual_seed(random.randint(1, 10000))

        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=300,
            temperature=0.9,
            do_sample=True,
            top_p=0.95,
            repetition_penalty=1.1
        )
        # Strip the prompt tokens, keeping only newly generated ones.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _parse_response(self, response, expected_columns):
        """Extract rows from the model response.

        Only rows with exactly ``expected_columns`` entries are kept; every
        cell is coerced to ``str``. Returns [] when nothing parseable is found.
        """
        rows = []
        try:
            # First look for a full nested list: [['a', 'b'], ['c', 'd']].
            list_pattern = r'\[\s*\[.*?\]\s*\]'
            matches = re.findall(list_pattern, response, re.DOTALL)

            if matches:
                # The longest match is most likely the complete table.
                largest_match = max(matches, key=len)
                try:
                    # literal_eval (not eval): the response is untrusted
                    # model output and must never execute code.
                    parsed_data = ast.literal_eval(largest_match)
                    if isinstance(parsed_data, list):
                        for row in parsed_data:
                            if isinstance(row, list) and len(row) == expected_columns:
                                rows.append([str(item) for item in row])
                except (ValueError, SyntaxError):
                    # Malformed literal — fall through to per-row parsing.
                    pass

            if not rows:
                # Fall back to individual bracketed rows: ['a', 'b'].
                row_pattern = r'\[([^\[\]]+)\]'
                for match in re.findall(row_pattern, response):
                    items = [item.strip().strip('"\'') for item in match.split(',')]
                    if len(items) == expected_columns:
                        rows.append(items)
        except Exception as e:
            # Best-effort parser: report and return whatever was collected.
            print(f"Error parsing response: {e}")

        return rows

    def _generate_fallback_data(self, llm_commands, needed_rows, current_count):
        """Generate deterministic-ish rows when the LLM under-delivers.

        Picks a generator per column by keyword-matching the command text;
        unknown columns get a unique "data_N" placeholder.
        """
        fallback_rows = []

        for i in range(needed_rows):
            row = []
            for cmd in llm_commands:
                cmd_lower = cmd.lower()
                if 'age' in cmd_lower:
                    # Special-case the sample prompt "ages between 1 to 20".
                    if 'between' in cmd_lower and '1' in cmd_lower and '20' in cmd_lower:
                        row.append(str(random.randint(1, 20)))
                    else:
                        row.append(str(random.randint(18, 65)))
                elif 'arabic' in cmd_lower and 'name' in cmd_lower:
                    arabic_names = ['محمد', 'أحمد', 'عبدالله', 'خالد', 'سعد',
                                    'فهد', 'عبدالعزيز', 'ناصر', 'سلطان', 'طلال',
                                    'فاطمة', 'عائشة', 'خديجة', 'مريم', 'زينب',
                                    'سارة', 'نورا', 'هند', 'لطيفة', 'منى']
                    row.append(random.choice(arabic_names))
                elif 'name' in cmd_lower:
                    names = ['John', 'Jane', 'Michael', 'Sarah', 'David',
                             'Lisa', 'Robert', 'Emily', 'James', 'Jessica']
                    row.append(random.choice(names))
                elif 'price' in cmd_lower or 'cost' in cmd_lower:
                    row.append(str(random.randint(10, 1000)))
                elif 'city' in cmd_lower:
                    cities = ['New York', 'London', 'Tokyo', 'Paris', 'Sydney',
                              'Cairo', 'Dubai', 'Berlin', 'Rome', 'Madrid']
                    row.append(random.choice(cities))
                else:
                    row.append(f"data_{current_count + i + 1}")
            fallback_rows.append(row)

        return fallback_rows


# Lazily constructed singleton: loading the 3B model takes seconds and
# gigabytes — reloading it on every UI request would be prohibitive.
_GENERATOR = None


def _get_generator():
    """Return the shared TableDataGenerator, creating it on first use."""
    global _GENERATOR
    if _GENERATOR is None:
        _GENERATOR = TableDataGenerator()
    return _GENERATOR


def generate_table_data(json_input, num_rows=1000):
    """Main function to generate table data from JSON input.

    Args:
        json_input: JSON string with an "llm_commands" list of column specs.
        num_rows: number of rows to generate.

    Returns:
        (result_text, json_data) — a human-readable summary and the generated
        payload dict ({} / [] on error, matching the original contract).
    """
    try:
        data = json.loads(json_input)
        llm_commands = data.get('llm_commands', [])

        if not llm_commands:
            return "Error: No llm_commands found in JSON input", []

        generator = _get_generator()
        rows = generator.generate_batch_data(llm_commands, num_rows)

        json_data = {
            "columns": llm_commands,
            "rows": rows,
            "total_rows": len(rows)
        }

        # Save with ensure_ascii=False so Arabic text stays readable on disk.
        os.makedirs('./train', exist_ok=True)
        with open('./train/data.json', "w", encoding="utf-8") as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)

        result = f"Generated {len(rows)} rows (requested: {num_rows}):\n"
        result += f"Columns: {llm_commands}\n"
        result += f"Saved to: ./train/data.json\n\n"
        result += "First 10 rows:\n"
        for i, row in enumerate(rows[:10]):
            result += f"{i+1}: {row}\n"
        if len(rows) > 10:
            result += f"\n... and {len(rows) - 10} more rows"

        return result, json_data

    except json.JSONDecodeError:
        return "Error: Invalid JSON format", {}
    except Exception as e:
        return f"Error: {str(e)}", {}


# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------

def process_json_input(json_input, num_rows):
    """Process JSON input and return (preview text, downloadable JSON string)."""
    result_text, json_data = generate_table_data(json_input, int(num_rows))

    if json_data and 'rows' in json_data:
        json_content = json.dumps(json_data, ensure_ascii=False, indent=2)
        return result_text, json_content
    return result_text, ""


with gr.Blocks(title="Table Data Generator") as demo:
    gr.Markdown("# Table Data Generator using LLM")
    gr.Markdown("Generate realistic table data based on column descriptions")

    with gr.Row():
        with gr.Column():
            json_input = gr.Textbox(
                label="JSON Input",
                placeholder='{"llm_commands": ["ages between 1 to 20", "arabic name"]}',
                lines=3,
                value='{"llm_commands": ["ages between 1 to 20", "arabic name"]}'
            )
            num_rows = gr.Slider(
                minimum=10,
                maximum=2000,
                value=100,
                step=10,
                label="Number of rows to generate"
            )
            generate_btn = gr.Button("Generate Data", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Data Preview",
                lines=15,
                max_lines=20
            )
            download_json = gr.File(
                label="Download JSON",
                visible=True
            )

    def generate_and_save(json_input, num_rows):
        """Click handler: generate data and stage it in a temp file for download."""
        result_text, json_content = process_json_input(json_input, num_rows)

        if json_content:
            # delete=False: Gradio needs the file to outlive this handler.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
                                             delete=False, encoding='utf-8') as f:
                f.write(json_content)
                temp_path = f.name
            return result_text, temp_path
        return result_text, None

    generate_btn.click(
        fn=generate_and_save,
        inputs=[json_input, num_rows],
        outputs=[output_text, download_json]
    )

    gr.Examples(
        examples=[
            ['{"llm_commands": ["ages between 1 to 20", "arabic name"]}', 50],
            ['{"llm_commands": ["random city", "population number", "country"]}', 100],
            ['{"llm_commands": ["product name", "price in USD", "category"]}', 75],
            ['{"llm_commands": ["email address", "phone number", "job title"]}', 60]
        ],
        inputs=[json_input, num_rows]
    )


if __name__ == "__main__":
    demo.launch()

# Example usage:
# json_input = '{"llm_commands": ["ages between 1 to 20", "arabic name"]}'
# result_text, json_data = generate_table_data(json_input, 1000)
# print(result_text)
# print(f"Actual rows generated: {len(json_data.get('rows', []))}")