import pandas as pd
import chromadb
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline
import gradio as gr
import email
|
|
# Load the Enron email dataset and strip headers and quoting from each message.
emails = pd.read_csv('emails.csv')

def preprocess_email_content(raw_email):
    # Parse the raw RFC 2822 message and keep only the body text.
    message = email.message_from_string(raw_email).get_payload()
    return message.replace("\n", "").replace("\r", "").replace("> >>> > >", "").strip()
|
|
content_text = [preprocess_email_content(item) for item in emails['message']]
# Keep only a tiny random sample of the corpus so the demo trains quickly.
train_content, _ = train_test_split(content_text, train_size=0.00005)
|
|
# Index the sampled emails in a ChromaDB collection.
client = chromadb.Client()
collection = client.create_collection(name="Enron_emails")
collection.add(documents=train_content, ids=[f'id{i+1}' for i in range(len(train_content))])
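# Illustrative sketch, not part of the original script: a quick similarity
# query against the collection to confirm the emails were indexed. The query
# text and n_results value here are arbitrary examples.
sample_hits = collection.query(query_texts=["energy trading"], n_results=min(3, len(train_content)))
print(sample_hits['ids'])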
|
|
# Load the GPT-2 tokenizer and model, adding the padding token GPT-2 lacks.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))  # account for the new [PAD] token
|
|
# Write the preprocessed email text to a plain-text file; TextDataset reads raw
# text and handles tokenization and chunking itself.
with open('train_emails.txt', 'w') as file:
    for text in train_content:
        file.write(text + '\n')


dataset = TextDataset(tokenizer=tokenizer, file_path='train_emails.txt', block_size=128)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=3,
    per_device_train_batch_size=8
)
|
|
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset
)
trainer.train()
|
|
# Save the fine-tuned model and tokenizer.
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")
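# Illustrative sketch, not part of the original script: reload the fine-tuned
# weights from disk so the rest of the script uses the saved checkpoint.
model = GPT2LMHeadModel.from_pretrained("./fine_tuned_model")
tokenizer = GPT2Tokenizer.from_pretrained("./fine_tuned_model")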
|
|
# Answer a question by sampling from the fine-tuned model via the
# text-generation pipeline defined below.
def question_answer(question):
    try:
        generated = text_gen(question, max_length=200, num_return_sequences=1)
        generated_text = generated[0]['generated_text'].replace(question, "").strip()
        return generated_text
    except Exception as e:
        return f"Error in generating response: {str(e)}"
|
|
text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
iface = gr.Interface(
    fn=question_answer,
    inputs="text",
    outputs="text",
    title="Answering questions about the Enron case",
    description="Ask a question about the Enron case!",
    examples=["What is Enron?"]
)
iface.launch()