| import pandas as pd |
| import matplotlib.pyplot as plt |
| import joblib |
| import gradio as gr |
| from dateutil.relativedelta import relativedelta |
| import calendar |
|
|
def load_model(path='arima_sales_model.pkl'):
    """Load the serialized forecasting model with joblib.

    Args:
        path: File to load. Defaults to ``'arima_sales_model.pkl'``, resolved
            relative to the current working directory.

    Returns:
        ``(model, None)`` on success, or ``(None, message)`` when loading
        fails for any reason.
    """
    # SECURITY: joblib.load unpickles arbitrary objects and can execute code
    # while doing so. Only point `path` at model files from a trusted source,
    # never at user uploads.
    try:
        model = joblib.load(path)
    except Exception as e:
        # Broad on purpose: this result is shown directly in the UI, and a bad
        # model file can fail in many different ways (missing file, truncated
        # pickle, missing dependency); callers only need a readable message.
        return None, f"Failed to load model: {str(e)}"
    return model, None
|
|
def parse_date(date_str):
    """Parse the custom date format 'Month-Year' into that month's bounds.

    Args:
        date_str: Text such as ``"January-2024"``: full month name, a hyphen,
            and a four-digit year. Leading/trailing whitespace is ignored.

    Returns:
        ``(start_date, end_date, None)``, where both are ``pd.Timestamp``
        values at midnight on the first and last day of the month. Returns
        ``(None, None, message)`` if the input is missing or malformed.
    """
    message = "Date format should be 'Month-Year', e.g., 'January-2024'."
    if date_str is None:
        # pd.to_datetime(None) returns None rather than raising, which would
        # otherwise surface later as an AttributeError on `.year`.
        return None, None, message
    if isinstance(date_str, str):
        # Text boxes commonly carry stray spaces that the strict format rejects.
        date_str = date_str.strip()
    try:
        date = pd.to_datetime(date_str, format="%B-%Y")
    except (ValueError, TypeError):
        return None, None, message
    if pd.isna(date):
        # pandas can return NaT (e.g. for an empty string) instead of raising.
        return None, None, message
    _, last_day = calendar.monthrange(date.year, date.month)
    return date.replace(day=1), date.replace(day=last_day), None
|
|
def forecast_sales(uploaded_file, start_date_str, end_date_str, horizon_days=60):
    """Plot actual sales for a month range followed by a model forecast.

    This is the Gradio callback wired to two outputs (plot, notifications), so
    every return path yields exactly ``(figure_or_None, message)``.

    Args:
        uploaded_file: Value from the file upload. Either a path string or a
            file-like object; if it exposes ``.name`` that is used as the path
            (assumed to be how the installed Gradio version hands files over).
            Must be a CSV with ``Date`` and ``Sale`` columns.
        start_date_str: First month to show, formatted like ``"January-2024"``.
        end_date_str: Last month to show (inclusive), same format.
        horizon_days: Number of daily steps to request from the model and plot
            after the end of the selected range.

    Returns:
        ``(fig, message)`` on success, otherwise ``(None, error_message)``.
    """
    if uploaded_file is None:
        return None, "No file uploaded. Please upload a file."

    df, error = _read_sales_csv(uploaded_file)
    if error:
        return None, error

    start_date, _, error = parse_date(start_date_str)
    _, end_date, error_end = parse_date(end_date_str)
    if error or error_end:
        return None, error or error_end
    if start_date > end_date:
        return None, "Start date must not be after the end date."

    # end_date is midnight of the month's last day; compare against the next
    # midnight so rows timestamped later on that last day are still included.
    in_range = (df['ds'] >= start_date) & (df['ds'] < end_date + pd.Timedelta(days=1))
    # Sort so an unordered CSV does not draw a zig-zag line.
    df_filtered = df[in_range].sort_values('ds')

    arima_model, error = load_model()
    if arima_model is None:
        return None, error

    try:
        fig = _build_forecast_figure(df_filtered, arima_model, end_date, horizon_days)
    except Exception as e:
        # Model or plotting failures are reported in the UI instead of crashing.
        return None, f"Failed to generate plot: {str(e)}"

    if df_filtered.empty:
        return fig, "File loaded, but no rows fall in the selected date range; showing the forecast only."
    return fig, "File loaded and processed successfully."


def _read_sales_csv(uploaded_file):
    """Read the CSV into a frame with parsed ``ds`` dates and ``y`` sales.

    Returns ``(df, None)`` on success or ``(None, message)`` on failure.
    """
    # Prefer an on-disk path when the upload object exposes one.
    source = getattr(uploaded_file, "name", uploaded_file)
    try:
        df = pd.read_csv(source)
    except Exception as e:
        return None, f"Failed to read the uploaded CSV file: {str(e)}"
    if 'Date' not in df.columns or 'Sale' not in df.columns:
        return None, "The uploaded file must contain 'Date' and 'Sale' columns."
    try:
        df['Date'] = pd.to_datetime(df['Date'])
    except (ValueError, TypeError) as e:
        return None, f"Could not parse the 'Date' column: {str(e)}"
    return df.rename(columns={'Date': 'ds', 'Sale': 'y'}), None


def _build_forecast_figure(df_filtered, arima_model, end_date, horizon_days):
    """Draw actual sales and a ``horizon_days``-step forecast on a new Figure."""
    # Use Figure directly rather than pyplot: pyplot keeps every figure in a
    # global registry until explicitly closed, which leaks memory per request
    # in a long-running server.
    from matplotlib.figure import Figure

    forecast = arima_model.get_forecast(steps=horizon_days)
    # NOTE(review): the model forecasts from the end of its training data, but
    # these dates start the day after the selected end month. They only line
    # up if the model was trained through end_date — confirm.
    forecast_index = pd.date_range(start=end_date, periods=horizon_days + 1, freq='D')[1:]
    # Drop any index on predicted_mean so values pair with forecast_index by
    # position, whatever index the model attached.
    forecast_values = pd.Series(forecast.predicted_mean).to_numpy()

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.plot(df_filtered['ds'], df_filtered['y'], label='Actual Sales', color='blue')
    ax.plot(forecast_index, forecast_values, label='Sales Forecast', color='red', linestyle='--')
    ax.set_xlabel('Date')
    ax.set_ylabel('Sales')
    ax.set_title('Sales Forecasting with ARIMA')
    ax.legend()
    return fig
|
|
def setup_interface():
    """Assemble the MLCast Gradio UI and wire up the forecast button.

    Left column: CSV upload, start/end month text fields and the trigger
    button. Right column: a plot area and a notifications text box. Clicking
    the button calls ``forecast_sales`` with the three inputs and sends its
    two return values to the plot and the notifications box.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as app:
        gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")
        with gr.Row():
            # Inputs on the left (narrow column).
            with gr.Column(scale=1):
                csv_upload = gr.File(label="Upload your store data")
                first_month = gr.Textbox(label="Start Date", placeholder="January-2024")
                last_month = gr.Textbox(label="End Date", placeholder="December-2024")
                run_button = gr.Button("Forecast Sales")
            # Results on the right (twice as wide).
            with gr.Column(scale=2):
                plot_panel = gr.Plot()
                status_box = gr.Textbox(label="Notifications", visible=True, lines=2)
        run_button.click(
            fn=forecast_sales,
            inputs=[csv_upload, first_month, last_month],
            outputs=[plot_panel, status_box],
        )
    return app
|
|
if __name__ == "__main__":
    # Build the UI and start Gradio's server with launch()'s default settings.
    setup_interface().launch()
|
|