# Consumer
import time
import pika
import os
from Server import get_response
import json
from agent.agent_graph.StateTasks import ProblemState
import argparse
import redis
from encryption_utils import decrypt_token_from_json

##################################################
# VARIABLES
##################################################
# args for this file
# NOTE: the command line is parsed at import time, so importing this module
# from another program also consumes that program's sys.argv.
argparse_model = argparse.ArgumentParser()
argparse_model.add_argument("--id", type=int, default=0, help="Consumer ID")
consumer_id = argparse_model.parse_args().id

# Required settings: a missing variable raises KeyError at startup.
RABBITMQ_URL = os.environ["RABBITMQ_URL"]
QUEUE_NAME = os.environ["QUEUE_NAME"]
redis_host = os.environ["REDIS_HOST"]
redis_port = os.environ["REDIS_PORT"]
redis_password = os.environ["REDIS_PASSWORD"]


##################################################
# PROCESSING METHODS
##################################################
def redis_send(user_id, msg_id, answer):
    """Store the JSON-encoded answer in Redis for one user message.

    The key is ``ANSWER_FOR_USER_ID{user_id}_OF_{msg_id}``.

    Args:
        user_id: Identifier interpolated into the Redis key.
        msg_id: Message identifier interpolated into the Redis key.
        answer: JSON-serializable object; stored as ``json.dumps(answer)``.

    Returns:
        Whatever the Redis client's ``set`` call returns.
    """
    r = redis.Redis(
        host=redis_host,
        port=redis_port,
        decode_responses=True,
        username="default",
        password=redis_password,
    )
    try:
        return r.set(
            f'ANSWER_FOR_USER_ID{user_id}_OF_{msg_id}',
            json.dumps(answer),
        )
    finally:
        # A client is created per message; release it deterministically
        # instead of waiting for garbage collection.
        r.close()


def model_call(request, token):
    """Run the model for one request and return a JSON-friendly answer.

    Args:
        request: Mapping read for the keys ``prompt``, ``memory``,
            ``user_email``, ``user_name`` and, optionally, ``last_state``
            (a JSON string holding the previous state).
        token: Decrypted token forwarded to ``get_response``.

    Returns:
        The object returned by ``get_response`` with its ``llm`` and
        ``rag_model`` entries overwritten by empty strings.
    """
    # fill with last state
    try:
        state = json.loads(request['last_state'])
    except (KeyError, TypeError, ValueError):
        # No usable previous state (key missing, value not a string, or
        # invalid JSON): start from a fresh state built from the request.
        state: ProblemState = {
            "question": request['prompt'],
            "memory": request['memory'],
        }

    answer = get_response(
        request['prompt'],
        request['memory'],
        token,
        state,
        request['user_email'],
        request['user_name'],
    )

    # drop unserializable keys
    for k in ["llm", "rag_model"]:
        answer[k] = ""
    return answer


def process_message(recieved_msg):
    """Decrypt the token, call the model and publish the answer to Redis.

    Args:
        recieved_msg: Decoded message mapping. Besides the keys read by
            ``model_call`` it must contain ``ht_token_encrypted_dumped``
            (a JSON string), ``user_id`` and ``msg_id``.
    """
    # decrypt token
    token = decrypt_token_from_json(
        json.loads(recieved_msg['ht_token_encrypted_dumped'])
    )

    # call the model
    model_answer = model_call(recieved_msg, token)

    # send answer to redis
    user_id = recieved_msg["user_id"]
    msg_id = recieved_msg["msg_id"]
    redis_send_res = redis_send(user_id, msg_id, model_answer)

    # Log the plain consumer id (it used to be wrapped in a set literal).
    # TODO: add monitoring but still hide user data
    print({"STATUS": redis_send_res, "CONSUMER": consumer_id})
##################################################
# CONSUMER METHODS
##################################################
def get_connection():
    """Open and return a blocking RabbitMQ connection built from RABBITMQ_URL."""
    params = pika.URLParameters(RABBITMQ_URL)
    return pika.BlockingConnection(params)


def callback(ch, method, properties, body):
    """Handle one delivery: decode it, process it, then acknowledge it.

    Args:
        ch: Channel the message arrived on; used to send the ack.
        method: Delivery metadata; supplies the ``delivery_tag`` to ack.
        properties: Message properties (unused).
        body: Raw message bytes containing a JSON document.

    The ack is sent only after ``process_message`` returns. If decoding or
    processing raises, the exception propagates and the message stays
    unacknowledged.
    """
    ##### Receive message and process it
    recieved_msg = json.loads(body.decode())
    print("-------------------------------------------------")
    print(f"MSG AT CONSUMER {consumer_id}" )

    ##### Process Message
    process_message(recieved_msg)

    ###### Finalize
    ch.basic_ack(delivery_tag=method.delivery_tag)


def start_consumer():
    """Declare the durable queue and consume from it until stopped.

    ``prefetch_count=1`` limits this consumer to one unacknowledged message
    at a time. The connection is closed when consuming ends, including when
    a callback raises.
    """
    # when scaled, each server runs its own consumer
    connection = get_connection()
    try:
        channel = connection.channel()
        channel.queue_declare(queue=QUEUE_NAME, durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue=QUEUE_NAME,
            on_message_callback=callback,
        )
        print("Waiting for messages...")
        channel.start_consuming()
    finally:
        # Closing an already-closed connection raises, so check first.
        if connection.is_open:
            connection.close()


##################################################
# MAIN
##################################################
if __name__ == "__main__":
    print(f"Starting New Consumer {consumer_id}...")
    start_consumer()