| import streamlit as st
|
| import fitz
|
| import torch
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| import hashlib
|
| from langchain.text_splitter import CharacterTextSplitter
|
| from langchain.vectorstores import FAISS
|
| from langchain.embeddings import OllamaEmbeddings
|
|
|
|
|
|
|
# Directory holding the locally fine-tuned TinyLlama tax model + tokenizer.
MODEL_PATH = "./fine_tuned_tinyllama_tax"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load the causal LM in half precision; device_map="auto" lets Accelerate
# place the weights on the available device(s) — requires `accelerate`.
model = AutoModelForCausalLM.from_pretrained(

MODEL_PATH,

torch_dtype=torch.float16,

device_map="auto"

)

# Shared text-generation pipeline used by summarize_text / answer_user_query.
tax_llm = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
|
|
|
|
|
|
# Ensure every per-session key exists before any widget logic touches it.
_SESSION_DEFAULTS = {
    "legal_knowledge_base": "",
    "vector_db": None,
    "summary": "",
    "answer": "",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
|
|
def compute_file_hash(file):
    """Return the SHA-256 hex digest of *file*'s contents.

    Reads the file-like object in fixed-size chunks so arbitrarily large
    uploads never have to fit in memory at once, then rewinds the stream
    so downstream consumers can read it again.

    Args:
        file: Binary file-like object supporting ``read()`` and ``seek()``.

    Returns:
        str: 64-character hexadecimal SHA-256 digest.
    """
    hasher = hashlib.sha256()
    # Chunked read keeps memory bounded for large PDFs.
    for chunk in iter(lambda: file.read(8192), b""):
        hasher.update(chunk)
    file.seek(0)  # Rewind so the caller sees the full stream again.
    return hasher.hexdigest()
|
|
|
def extract_text_from_pdf(pdf_file):
    """Extract all page text from a PDF using PyMuPDF (fitz).

    Args:
        pdf_file: Binary file-like object containing the PDF bytes.

    Returns:
        str: Concatenated page text, or a fixed fallback message when the
        document yields no extractable text (e.g. a scanned image PDF).
    """
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    pdf_file.seek(0)  # Rewind so the upload can be read again (e.g. hashing).
    try:
        text = "\n".join(page.get_text("text") for page in doc)
    finally:
        # The original leaked the document handle; always release it.
        doc.close()
    return text.strip() if text.strip() else "No extractable text found in PDF."
|
|
|
def summarize_text(text):
    """Summarize a tax policy document with the fine-tuned model.

    Uses ``max_new_tokens`` rather than ``max_length``: ``max_length``
    counts the prompt tokens too, so any document longer than the budget
    would leave no room for (or outright break) generation. Also strips
    the echoed prompt, since text-generation pipelines return prompt +
    continuation in ``generated_text``.

    Args:
        text: Raw document text to summarize.

    Returns:
        str: The model-generated summary.
    """
    prompt = f"Summarize this tax policy document concisely:\n{text}"
    generated = tax_llm(prompt, max_new_tokens=200, do_sample=True)[0]["generated_text"]
    # Keep only the model's continuation, not the echoed prompt.
    return generated[len(prompt):].strip() if generated.startswith(prompt) else generated
|
|
|
def create_vector_db():
    """Build a FAISS similarity index over the session's legal text.

    Splits the extracted knowledge base into overlapping chunks, embeds
    them with Ollama, and returns the resulting FAISS store.

    Returns:
        A FAISS vector store, or None when no document text is available.
    """
    raw_text = st.session_state.legal_knowledge_base
    if not raw_text:
        return None

    splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=150)
    chunks = splitter.split_text(raw_text)
    return FAISS.from_texts(chunks, OllamaEmbeddings(model="llama3:8b"))
|
|
|
def retrieve_relevant_text(query, vector_db):
    """Return the top-5 most relevant document chunks, joined by newlines.

    Args:
        query: The user's question to search against the index.
        vector_db: FAISS store built by ``create_vector_db`` (may be None).

    Returns:
        str: Matching chunk text, or a fixed fallback message when no
        vector store has been built yet.
    """
    if not vector_db:
        return "No document uploaded."

    matches = vector_db.similarity_search(query, k=5)
    return "\n".join(doc.page_content for doc in matches)
|
|
|
def compute_tax_details(query):
    """Parse an income figure and a tax rate out of *query* and compute the tax.

    Args:
        query: Free-form user text, e.g. "What is the tax on ₹5,00,000 at 10%?".

    Returns:
        str | None: A formatted answer when both an income amount and a
        percentage rate are present; otherwise None so the caller can fall
        back to the language model.
    """
    import re

    # Drop thousands separators once, up front.
    normalized = query.replace(",", "")

    # The original pattern contained a mojibake-corrupted rupee sign ("βΉ?"),
    # which made the first corrupted character a *mandatory* prefix, so plain
    # numbers never matched. The lookahead skips digit runs that belong to the
    # "NN%" rate token (e.g. "10% on 500000" must not read income=10).
    income_match = re.search(r"₹?\s*(\d+)(?!\d*%)", normalized)
    tax_rate_match = re.search(r"(\d+)\s*%", normalized)

    if not (income_match and tax_rate_match):
        return None

    income = float(income_match.group(1))
    tax_rate = float(tax_rate_match.group(1))
    computed_tax = round(income * (tax_rate / 100), 2)
    return (
        f"Based on an income of ₹{income:,.2f} and a tax rate of {tax_rate:g}%, "
        f"the tax is **₹{computed_tax:,.2f}.**"
    )
|
|
|
def answer_user_query(query):
    """Answer a tax question, preferring exact computation over the LLM.

    Resolution order:
      1. If the query contains an income and a rate, compute the tax
         directly and store the result.
      2. Otherwise retrieve relevant document chunks and ask the
         fine-tuned model, storing the reply in ``st.session_state.answer``.

    Args:
        query: The user's question.
    """
    tax_computation_result = compute_tax_details(query)
    if tax_computation_result:
        st.session_state.answer = tax_computation_result
        return

    if not st.session_state.vector_db:
        st.error("Please upload a document first.")
        return

    retrieved_text = retrieve_relevant_text(query, st.session_state.vector_db)
    prompt = f"""
You are an AI tax expert. Use legal knowledge and tax calculations to answer.

Context:
{retrieved_text}

User Query:
{query}

Response:
"""

    # max_new_tokens budgets only the generated text; max_length would count
    # the (potentially long) retrieved context and truncate or error out.
    generated = tax_llm(prompt, max_new_tokens=300, do_sample=True)[0]["generated_text"]
    # The pipeline echoes the prompt; keep only the model's continuation.
    st.session_state.answer = (
        generated[len(prompt):].strip() if generated.startswith(prompt) else generated
    )
|
|
|
|
|
|
|
def main():
    """Render the Streamlit UI: upload, summarize, index, then Q&A.

    NOTE(review): the original UI strings contained mojibake-corrupted
    emoji ("π", "π¬", ...); replaced with plausible intact equivalents.
    """
    st.title("📊 AI Legal Tax Assistant")

    uploaded_file = st.file_uploader("📄 Upload Tax Policy PDF", type=["pdf"])

    if uploaded_file:
        # Streamlit reruns this script on every interaction. Hash the upload
        # so the expensive extract/summarize/index pipeline runs only when
        # the document actually changes (this is what compute_file_hash is
        # for — the original never called it).
        file_hash = compute_file_hash(uploaded_file)
        if st.session_state.get("file_hash") != file_hash:
            with st.spinner("Extracting text..."):
                extracted_text = extract_text_from_pdf(uploaded_file)
                st.session_state.legal_knowledge_base = extracted_text
                st.success("Document Uploaded!")

            with st.spinner("Generating summary..."):
                st.session_state.summary = summarize_text(extracted_text)

            with st.spinner("Indexing document..."):
                st.session_state.vector_db = create_vector_db()
                st.success("Document indexed! Ask questions now.")

            st.session_state.file_hash = file_hash

        if st.session_state.summary:
            st.subheader("📝 Document Summary:")
            # A non-empty label (hidden) avoids Streamlit's empty-label warning.
            st.text_area(
                "Document summary",
                st.session_state.summary,
                height=250,
                label_visibility="collapsed",
            )

    st.subheader("💬 Ask Questions:")
    user_query = st.text_input("Enter your question:")

    if st.button("Ask") and user_query.strip():
        with st.spinner("Processing..."):
            answer_user_query(user_query)

    if st.session_state.answer:
        st.markdown("### 🤖 AI Response:")
        st.success(st.session_state.answer)
|
|
|
# Entry point when executed directly (streamlit run <file>).
if __name__ == "__main__":

    main()
|
|
|