from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset, concatenate_datasets, DatasetDict
from config import Config
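
# NOTE: config.Config is project-local and not shown here. From its uses below,
# it is assumed to expose at least: MODEL_NAME, TOKENIZER_NAME, DEVICE,
# MAX_LENGTH, OUTPUT_DIR, LEARNING_RATE, BATCH_SIZE, WEIGHT_DECAY, NUM_EPOCHS,
# and LOGGING_DIR.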


class CyberAttackDetectionModel:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)

        # Many causal-LM tokenizers (GPT-style in particular) ship without a
        # pad token; padding in preprocess_data requires one, so fall back to
        # the EOS token when none is set.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.to(Config.DEVICE)

    def preprocess_data(self, dataset):
        """
        Preprocess a raw text dataset: normalize the text, tokenize it, and
        attach labels for causal-LM fine-tuning.
        """

        def clean_text(text):
            # Minimal normalization: lowercase and flatten newlines.
            text = text.lower()
            text = text.replace("\n", " ")
            return text

        # Assumes each source dataset exposes a 'text' column.
        dataset = dataset.map(lambda x: {'text': clean_text(x['text'])})

        def tokenize_function(examples):
            tokens = self.tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=Config.MAX_LENGTH,
            )
            # For causal-LM training the labels are the input ids themselves;
            # the model shifts them internally to predict the next token.
            # Without this, set_format below would fail on the missing column.
            tokens['labels'] = tokens['input_ids'].copy()
            return tokens

        # Drop the original columns so every processed dataset ends up with an
        # identical schema, which concatenate_datasets requires later on.
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
        )
        tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

        return tokenized_dataset
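
    # An alternative to copying input_ids into labels above is to build labels
    # at batch time with transformers.DataCollatorForLanguageModeling(
    # tokenizer=..., mlm=False), passed to the Trainer via its data_collator
    # argument. Copying in preprocess_data keeps the Trainer call below
    # unchanged.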

    def fine_tune(self, datasets):
        """
        Fine-tune the model on the preprocessed train/validation splits.
        """
        dataset_dict = DatasetDict({
            "train": datasets['train'],
            "validation": datasets['validation'],
        })

        training_args = TrainingArguments(
            output_dir=Config.OUTPUT_DIR,
            evaluation_strategy="epoch",
            # load_best_model_at_end requires matching save/eval strategies.
            save_strategy="epoch",
            learning_rate=Config.LEARNING_RATE,
            per_device_train_batch_size=Config.BATCH_SIZE,
            per_device_eval_batch_size=Config.BATCH_SIZE,
            weight_decay=Config.WEIGHT_DECAY,
            save_total_limit=3,
            num_train_epochs=Config.NUM_EPOCHS,
            logging_dir=Config.LOGGING_DIR,
            load_best_model_at_end=True,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset_dict['train'],
            eval_dataset=dataset_dict['validation'],
        )

        trainer.train()

    def predict(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=Config.MAX_LENGTH)
        inputs = {key: value.to(Config.DEVICE) for key, value in inputs.items()}

        # max_new_tokens bounds the generated continuation. The original
        # max_length=Config.MAX_LENGTH left no room to generate once the
        # prompt itself was truncated to MAX_LENGTH tokens.
        outputs = self.model.generate(**inputs, max_new_tokens=256)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
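
    # generate() defaults to greedy decoding; for more varied completions,
    # standard generation arguments such as do_sample=True, temperature, and
    # top_p can be passed through to generate() above.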

    def load_and_process_datasets(self):
        """
        Load the source datasets, preprocess each one, and return a
        DatasetDict with train and validation splits ready for fine-tuning.
        """
        osint_datasets = [
            'gonferspanish/OSINT',
            'Inforensics/missing-persons-clue-analysis-osint',
            'jester6136/osint',
            'originalbox/osint'
        ]

        wrn_datasets = [
            'WhiteRabbitNeo/WRN-Chapter-2',
            'WhiteRabbitNeo/WRN-Chapter-1',
            'WhiteRabbitNeo/Code-Functions-Level-Cyber'
        ]

        combined_datasets = []

        # Assumes each Hub dataset publishes a 'train' split.
        for dataset_name in osint_datasets + wrn_datasets:
            dataset = load_dataset(dataset_name)
            processed_data = self.preprocess_data(dataset['train'])
            combined_datasets.append(processed_data)

        # preprocess_data returns plain Dataset objects, so they are joined
        # with concatenate_datasets (pd.concat does not accept Datasets, and
        # none of the sources carries a 'validation' split). A validation set
        # is then carved out with a fixed seed; the 10% ratio is an arbitrary
        # but reproducible choice.
        full_dataset = concatenate_datasets(combined_datasets)
        split = full_dataset.train_test_split(test_size=0.1, seed=42)

        return DatasetDict({
            'train': split['train'],
            'validation': split['test'],
        })
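

# Running this module end to end downloads every dataset listed above from the
# Hugging Face Hub, so the first run needs network access and cache space.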
if __name__ == "__main__":
    model = CyberAttackDetectionModel()

    preprocessed_datasets = model.load_and_process_datasets()

    model.fine_tune(preprocessed_datasets)

    prompt = "A network scan reveals an open port 22 with an outdated SSH service."
    print(model.predict(prompt))