| |
| """satellite_app.ipynb |
| |
| Automatically generated by Colab. |
| |
| Original file is located at |
| https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27 |
| """ |
|
|
| import gradio as gr |
| from safetensors.torch import load_model |
| from timm import create_model |
| from huggingface_hub import hf_hub_download |
| from datasets import load_dataset |
| import torch |
| import torchvision.transforms as T |
| import cv2 |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from PIL import Image |
| import os |
|
|
| from langchain_community.document_loaders import TextLoader |
| from langchain_community.vectorstores import FAISS |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain.text_splitter import CharacterTextSplitter |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain_fireworks import ChatFireworks |
| from langchain_core.prompts import ChatPromptTemplate |
| from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
|
|
|
# File name of the safetensors checkpoint that holds the model weights.
safe_tensors = "model.safetensors"

# timm architecture identifier passed to create_model().
model_name = 'swin_s3_base_224'

# Build the network with 17 outputs, one per class name listed in
# one_hot_decoding().
model = create_model(
    model_name,
    num_classes=17
)

# Copy the checkpoint weights into the freshly created model.
load_model(model, safe_tensors)

# Put the network in inference mode. torch.no_grad() only turns off gradient
# tracking; layers such as dropout keep their training behaviour until eval()
# is called, which would make predictions vary between calls.
model.eval()
|
|
def one_hot_decoding(labels):
    """
    input: iterable of per-class flags, position i belonging to class i
    function: collects the class name of every position whose flag equals 1
    output: list of class names, in position order
    """
    names_by_position = dict(enumerate((
        'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
        'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy',
        'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine',
        'slash_burn', 'clear', 'haze',
    )))

    # First find the active positions, then translate them to names.
    active_positions = [pos for pos, flag in enumerate(labels) if flag == 1]
    return [names_by_position[pos] for pos in active_positions]
|
|
def ragChain():
    """
    function: creates a rag chain on top of the FAISS index stored in "faiss_index"
    output: rag chain (retriever -> prompt -> llm -> string parser)
    """
    # The vector store is read from the prebuilt "faiss_index" directory, so
    # the source document is not loaded or split here: the chain never used
    # the result of doing so.
    # SECURITY: allow_dangerous_deserialization enables pickle loading of the
    # index; only use it with an index file produced by a trusted source.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    # The key is read from the environment so it never lives in the source.
    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable landscape deforestation analyst.
                """
            ),
            (
                "human",
                """First mention the detected labels only with short description.
                Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
                Don't include conversational messages.
                """,
            ),
            ("human", "{context}, {question}"),
        ]
    )

    # NOTE(review): the retriever output is placed into {context} without any
    # formatting step; presumably it is rendered via str() - consider joining
    # the documents' page_content instead. TODO confirm.
    rag_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough()
        }
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain
|
|
def model_output(image):
    """
    input: image array (cast to uint8 and handed to Image.fromarray)
    function: classifies the image and builds an llm query from the predicted labels
    output: query string that names the detected labels
    raises: ValueError when image is None
    """
    # Fail early with a clear message instead of an AttributeError on None.
    if image is None:
        raise ValueError("No image provided: upload or select a satellite image first.")

    # Let Pillow infer the mode from the array shape and convert afterwards,
    # so grayscale and RGBA arrays are accepted alongside plain RGB ones.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    img_size = (224, 224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])

    img = test_tfms(pil_image)

    # unsqueeze(0) adds the batch dimension the model call expects.
    with torch.no_grad():
        logits = model(img.unsqueeze(0))

    # Multi-label decision: every class whose sigmoid score exceeds 0.5 counts
    # as detected.
    predictions = logits.sigmoid() > 0.5
    predictions = predictions.float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    output_text = " ".join(pred_labels)

    query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."

    return query
|
|
def generate_response(rag_chain, query):
    """
    input: rag chain, query
    function: renders the query as text and passes it to the chain's invoke()
    output: whatever the chain returns for that text
    """
    query_text = f"{query}"
    response = rag_chain.invoke(query_text)
    return response
| |
# Chain shared by all requests; stays None until the first request needs it.
_rag_chain = None


def _get_rag_chain():
    """Return the shared rag chain, building it on the first call only."""
    global _rag_chain
    if _rag_chain is None:
        # ragChain() loads the embedding model and the FAISS index, which is
        # too expensive to repeat for every request.
        _rag_chain = ragChain()
    return _rag_chain


def main(image):
    """
    input: image received from the gradio interface
    function: classifies the image, then queries the rag chain about the detected labels
    output: response text generated by the llm
    """
    query = model_output(image)
    chain = _get_rag_chain()
    return generate_response(chain, query)
# Texts shown at the top of the gradio page.
title = "Satellite Image Landscape Analysis for Deforestation"
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identify the landscape based on forest areas, roads, habitation, water etc."

# Sample inputs offered in the UI; each inner list is one set of arguments
# for main().
_sample_inputs = [
    ["sample_images/train_142.jpg"],
    ["sample_images/train_32.jpg"],
    ["sample_images/random_satellite3.png"],
    ["sample_images/random_satellite2.png"],
    ["sample_images/train_75.jpg"],
    ["sample_images/train_92.jpg"],
    ["sample_images/random_satellite.png"],
]

app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=_sample_inputs,
)
app.launch(share=True)
|
|
|
|