Spaces:
Sleeping
Sleeping
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import gradio as gr | |
| import json | |
| import torch | |
| import re | |
| import random | |
| class TableDataGenerator: | |
| def __init__(self, model_name="Qwen/Qwen2.5-3B-Instruct"): | |
| self.model_name = model_name | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype="auto", | |
| device_map="auto" | |
| ) | |
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| def generate_batch_data(self, llm_commands, num_rows=1000, batch_size=30): | |
| """Generate table data in batches for better performance""" | |
| all_rows = [] | |
| # Create column headers description | |
| columns_desc = ", ".join([f"Column {i+1}: {cmd}" for i, cmd in enumerate(llm_commands)]) | |
| # Calculate number of batches | |
| num_batches = (num_rows + batch_size - 1) // batch_size | |
| for batch_idx in range(num_batches): | |
| current_batch_size = min(batch_size, num_rows - len(all_rows)) | |
| # Try multiple attempts to get enough rows for this batch | |
| batch_rows = [] | |
| max_attempts = 5 | |
| for attempt in range(max_attempts): | |
| remaining_needed = current_batch_size - len(batch_rows) | |
| if remaining_needed <= 0: | |
| break | |
| # Create prompt for this batch | |
| prompt = f"""Generate exactly {remaining_needed} rows of data for a table with columns: | |
| {columns_desc} | |
| Format: [['value1', 'value2'], ['value3', 'value4']] | |
| Requirements: | |
| - Each row must be different and realistic | |
| - Return ONLY the list, no explanations | |
| - Make data diverse and creative | |
| - Seed: {batch_idx * 10 + attempt} | |
| Generate exactly {remaining_needed} rows:""" | |
| messages = [ | |
| {"role": "system", "content": "You are a precise data generator. Return only valid Python list format with exactly the requested number of rows."}, | |
| {"role": "user", "content": prompt} | |
| ] | |
| # Generate response | |
| response = self._generate_response(messages, batch_idx * 10 + attempt) | |
| # Parse the response to extract rows | |
| new_rows = self._parse_response(response, len(llm_commands)) | |
| # Add unique rows only | |
| for row in new_rows: | |
| if row not in batch_rows and row not in all_rows: | |
| batch_rows.append(row) | |
| if len(batch_rows) >= current_batch_size: | |
| break | |
| # Add to all rows | |
| all_rows.extend(batch_rows) | |
| # If we still don't have enough, generate fallback data | |
| if len(all_rows) < num_rows and len(batch_rows) < current_batch_size: | |
| fallback_rows = self._generate_fallback_data(llm_commands, current_batch_size - len(batch_rows), len(all_rows)) | |
| all_rows.extend(fallback_rows) | |
| # Break if we have enough rows | |
| if len(all_rows) >= num_rows: | |
| break | |
| return all_rows[:num_rows] | |
| def _generate_response(self, messages, seed=None): | |
| """Generate response from the model""" | |
| text = self.tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) | |
| # Set random seed for variety | |
| if seed is not None: | |
| torch.manual_seed(seed) | |
| else: | |
| torch.manual_seed(random.randint(1, 10000)) | |
| generated_ids = self.model.generate( | |
| **model_inputs, | |
| max_new_tokens=300, | |
| temperature=0.9, | |
| do_sample=True, | |
| top_p=0.95, | |
| repetition_penalty=1.1 | |
| ) | |
| generated_ids = [ | |
| output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | |
| ] | |
| response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| return response | |
| def _parse_response(self, response, expected_columns): | |
| """Parse the model response to extract table rows""" | |
| rows = [] | |
| try: | |
| # Try to find list-like patterns in the response | |
| # Look for patterns like [['value1', 'value2'], ['value3', 'value4']] | |
| list_pattern = r'\[\s*\[.*?\]\s*\]' | |
| matches = re.findall(list_pattern, response, re.DOTALL) | |
| if matches: | |
| # Try to evaluate the largest match | |
| largest_match = max(matches, key=len) | |
| try: | |
| parsed_data = eval(largest_match) | |
| if isinstance(parsed_data, list): | |
| for row in parsed_data: | |
| if isinstance(row, list) and len(row) == expected_columns: | |
| rows.append([str(item) for item in row]) | |
| except: | |
| pass | |
| # If no valid list found, try to extract individual rows | |
| if not rows: | |
| # Look for individual row patterns like ['value1', 'value2'] | |
| row_pattern = r'\[([^\[\]]+)\]' | |
| row_matches = re.findall(row_pattern, response) | |
| for match in row_matches: | |
| try: | |
| # Split by comma and clean up | |
| items = [item.strip().strip('"\'') for item in match.split(',')] | |
| if len(items) == expected_columns: | |
| rows.append(items) | |
| except: | |
| continue | |
| except Exception as e: | |
| print(f"Error parsing response: {e}") | |
| return rows | |
| def _generate_fallback_data(self, llm_commands, needed_rows, current_count): | |
| """Generate fallback data when LLM doesn't produce enough rows""" | |
| fallback_rows = [] | |
| # Simple fallback generators based on command type | |
| for i in range(needed_rows): | |
| row = [] | |
| for cmd in llm_commands: | |
| cmd_lower = cmd.lower() | |
| if 'age' in cmd_lower: | |
| if 'between' in cmd_lower and '1' in cmd_lower and '20' in cmd_lower: | |
| row.append(str(random.randint(1, 20))) | |
| else: | |
| row.append(str(random.randint(18, 65))) | |
| elif 'arabic' in cmd_lower and 'name' in cmd_lower: | |
| arabic_names = ['محمد', 'أحمد', 'عبدالله', 'خالد', 'سعد', 'فهد', 'عبدالعزيز', 'ناصر', 'سلطان', 'طلال', | |
| 'فاطمة', 'عائشة', 'خديجة', 'مريم', 'زينب', 'سارة', 'نورا', 'هند', 'لطيفة', 'منى'] | |
| row.append(random.choice(arabic_names)) | |
| elif 'name' in cmd_lower: | |
| names = ['John', 'Jane', 'Michael', 'Sarah', 'David', 'Lisa', 'Robert', 'Emily', 'James', 'Jessica'] | |
| row.append(random.choice(names)) | |
| elif 'price' in cmd_lower or 'cost' in cmd_lower: | |
| row.append(str(random.randint(10, 1000))) | |
| elif 'city' in cmd_lower: | |
| cities = ['New York', 'London', 'Tokyo', 'Paris', 'Sydney', 'Cairo', 'Dubai', 'Berlin', 'Rome', 'Madrid'] | |
| row.append(random.choice(cities)) | |
| else: | |
| row.append(f"data_{current_count + i + 1}") | |
| fallback_rows.append(row) | |
| return fallback_rows | |
| def generate_table_data(json_input, num_rows=1000): | |
| """Main function to generate table data from JSON input""" | |
| try: | |
| # Parse JSON input | |
| data = json.loads(json_input) | |
| llm_commands = data.get('llm_commands', []) | |
| if not llm_commands: | |
| return "Error: No llm_commands found in JSON input", [] | |
| # Initialize generator | |
| generator = TableDataGenerator() | |
| # Generate data | |
| rows = generator.generate_batch_data(llm_commands, num_rows) | |
| # Create JSON structure | |
| json_data = { | |
| "columns": llm_commands, | |
| "rows": rows, | |
| "total_rows": len(rows) | |
| } | |
| # Save to JSON file with proper Arabic encoding | |
| import os | |
| os.makedirs('./train', exist_ok=True) | |
| json.dump(json_data, open('./train/data.json', "w", encoding="utf-8"), ensure_ascii=False, indent=2) | |
| # Format output | |
| result = f"Generated {len(rows)} rows (requested: {num_rows}):\n" | |
| result += f"Columns: {llm_commands}\n" | |
| result += f"Saved to: ./train/data.json\n\n" | |
| # Show first 10 rows as preview | |
| result += "First 10 rows:\n" | |
| for i, row in enumerate(rows[:10]): | |
| result += f"{i+1}: {row}\n" | |
| if len(rows) > 10: | |
| result += f"\n... and {len(rows) - 10} more rows" | |
| return result, json_data | |
| except json.JSONDecodeError: | |
| return "Error: Invalid JSON format", {} | |
| except Exception as e: | |
| return f"Error: {str(e)}", {} | |
| # Gradio Interface | |
| def process_json_input(json_input, num_rows): | |
| """Process JSON input and return formatted results""" | |
| result_text, json_data = generate_table_data(json_input, int(num_rows)) | |
| # Return JSON content for download | |
| if json_data and 'rows' in json_data: | |
| json_content = json.dumps(json_data, ensure_ascii=False, indent=2) | |
| return result_text, json_content | |
| else: | |
| return result_text, "" | |
| # Create Gradio interface | |
| with gr.Blocks(title="Table Data Generator") as demo: | |
| gr.Markdown("# Table Data Generator using LLM") | |
| gr.Markdown("Generate realistic table data based on column descriptions") | |
| with gr.Row(): | |
| with gr.Column(): | |
| json_input = gr.Textbox( | |
| label="JSON Input", | |
| placeholder='{"llm_commands": ["ages between 1 to 20", "arabic name"]}', | |
| lines=3, | |
| value='{"llm_commands": ["ages between 1 to 20", "arabic name"]}' | |
| ) | |
| num_rows = gr.Slider( | |
| minimum=10, | |
| maximum=2000, | |
| value=100, | |
| step=10, | |
| label="Number of rows to generate" | |
| ) | |
| generate_btn = gr.Button("Generate Data", variant="primary") | |
| with gr.Column(): | |
| output_text = gr.Textbox( | |
| label="Generated Data Preview", | |
| lines=15, | |
| max_lines=20 | |
| ) | |
| download_json = gr.File( | |
| label="Download JSON", | |
| visible=True | |
| ) | |
| def generate_and_save(json_input, num_rows): | |
| result_text, json_content = process_json_input(json_input, num_rows) | |
| if json_content: | |
| # Save to temporary file | |
| import tempfile | |
| import os | |
| with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as f: | |
| f.write(json_content) | |
| temp_path = f.name | |
| return result_text, temp_path | |
| else: | |
| return result_text, None | |
| generate_btn.click( | |
| fn=generate_and_save, | |
| inputs=[json_input, num_rows], | |
| outputs=[output_text, download_json] | |
| ) | |
| # Example inputs | |
| gr.Examples( | |
| examples=[ | |
| ['{"llm_commands": ["ages between 1 to 20", "arabic name"]}', 50], | |
| ['{"llm_commands": ["random city", "population number", "country"]}', 100], | |
| ['{"llm_commands": ["product name", "price in USD", "category"]}', 75], | |
| ['{"llm_commands": ["email address", "phone number", "job title"]}', 60] | |
| ], | |
| inputs=[json_input, num_rows] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |
| # Example usage: | |
| # json_input = '{"llm_commands": ["ages between 1 to 20", "arabic name"]}' | |
| # result_text, json_data = generate_table_data(json_input, 1000) | |
| # print(result_text) | |
| # print(f"Actual rows generated: {len(json_data.get('rows', []))}") |