"""Gradio app that forecasts a time series with Prophet or Chronos.

The series comes either from an uploaded CSV export (a "Google"-style file
whose second line is blank, or a semicolon-separated "Yandex"-style file) or
from a ClickHouse query built from the UI filters. The app plots the observed
data with a two-year forecast, reports the year-over-year change and writes a
combined CSV for download.
"""
import contextlib
import functools
import logging
import os
import re
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import requests
import torch
from chronos import ChronosPipeline
from clickhouse_driver import Client
from prophet import Prophet

try:
    from google.colab import userdata
    PG_PASSWORD = userdata.get('FASHION_PG_PASS')
    CH_PASSWORD = userdata.get('FASHION_CH_PASS')
except Exception:
    # Not running in Colab (or the Colab secret lookup failed): fall back to
    # environment variables. `Exception` rather than a bare `except` so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    PG_PASSWORD = os.environ['FASHION_PG_PASS']
    CH_PASSWORD = os.environ['FASHION_CH_PASS']

logging.getLogger("prophet").setLevel(logging.WARNING)
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)

# Dictionary to map Russian month names to month numbers
russian_months = {
    "январь": "01", "февраль": "02", "март": "03", "апрель": "04",
    "май": "05", "июнь": "06", "июль": "07", "август": "08",
    "сентябрь": "09", "октябрь": "10", "ноябрь": "11", "декабрь": "12",
}

# "<month name> <4-digit year>", as found in monthly Yandex exports.
_MONTH_YEAR_RE = re.compile(r'(\w+)\s(\d{4})')

# Root CA used to verify the TLS connection to the managed ClickHouse host.
_ROOT_CA_URL = 'https://storage.yandexcloud.net/cloud-certs/RootCA.pem'


def _russian_month_to_iso(text):
    """Rewrite "<russian month> <yyyy>" in *text* as "yyyy-MM" and append "-01".

    Raises:
        KeyError: if the matched word is not a key of ``russian_months``.
    """
    def _swap(match):
        return f'{match.group(2)}-{russian_months[match.group(1).lower()]}'

    return _MONTH_YEAR_RE.sub(_swap, text) + '-01'


def read_and_process_file(file):
    """Parse an uploaded CSV export into a date-indexed DataFrame.

    Args:
        file: The uploaded file, either an object exposing the path as
            ``.name`` or the path itself.

    Returns:
        A ``(data, period_type, period_col)`` tuple: the DataFrame indexed by
        the parsed date column, ``"Week"`` or ``"Month"``, and the name of the
        date column.

    Raises:
        gr.Error: if the file has fewer than three lines.
    """
    # Gradio may hand over a tempfile-like wrapper or a plain path string.
    path = getattr(file, "name", file)

    # Read once; the encoding is explicit because the exports contain
    # Cyrillic text and pandas.read_csv below also decodes as UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    if len(lines) < 3:
        raise gr.Error("The uploaded file is too short to contain a time series")

    # Weekly exports mention "Неделя"/"Week" (case-insensitive) in the header.
    header = ''.join(lines[:3]).lower()
    if any(word in header for word in ("неделя", "week")):
        period_type = "Week"
    else:
        period_type = "Month"

    if lines[1].strip() == '':
        # "Google" layout: a title line, a blank line, then the CSV table.
        data = pd.read_csv(path, skiprows=2)
    else:
        # "Yandex" layout: semicolon-separated; keep the 1st and 3rd columns.
        data = pd.read_csv(path, sep=';', skiprows=0, usecols=[0, 2])
        date_col = data.columns[0]
        # Assign by label rather than `.iloc[:, 0] = ...`: positional
        # assignment writes in place and can leave the column as `object`.
        if period_type == "Month":
            data[date_col] = data[date_col].apply(_russian_month_to_iso)
        if period_type == "Week":
            data[date_col] = pd.to_datetime(data[date_col], format="%d.%m.%Y")

    # Normalise the value column: "<1" -> 0, drop spaces (thousands
    # separators), decimal comma -> decimal point.
    value_col = data.columns[1]
    data[value_col] = (
        data[value_col]
        .astype(str)
        .str.replace('<1', '0', regex=False)
        .str.replace(' ', '', regex=False)
        .str.replace(',', '.', regex=False)
        .astype(float)
    )

    # Process the date column and set it as the index
    period_col = data.columns[0]
    data[period_col] = pd.to_datetime(data[period_col])
    data.set_index(period_col, inplace=True)

    return data, period_type, period_col


def get_data_from_db(query):
    """Run *query* against ClickHouse and return the rows as a DataFrame.

    The root CA certificate is downloaded to a temporary file for the TLS
    connection; the file is removed and the client disconnected afterwards.

    Args:
        query: SQL text to execute.

    Returns:
        A DataFrame whose columns are named after the result columns.

    Raises:
        requests.RequestException: if the certificate cannot be downloaded.
    """
    response = requests.get(_ROOT_CA_URL, timeout=30)
    response.raise_for_status()

    with tempfile.NamedTemporaryFile(delete=False, suffix='.pem') as temp_cert_file:
        temp_cert_file.write(response.content)
        cert_file_path = temp_cert_file.name

    client = None
    try:
        client = Client(
            host='rc1d-a93v7vf0pjfr6e2o.mdb.yandexcloud.net',
            port=9440,
            user='user1',
            password=CH_PASSWORD,
            database='db1',
            secure=True,
            ca_certs=cert_file_path,
        )
        result, columns = client.execute(query, with_column_types=True)
    finally:
        if client is not None:
            client.disconnect()
        # The certificate is only needed while the connection is alive.
        with contextlib.suppress(OSError):
            os.remove(cert_file_path)

    column_names = [col[0] for col in columns]
    return pd.DataFrame(result, columns=column_names)


def _yoy_change(forecast_sum, baseline_sum):
    """Return the relative change of *forecast_sum* versus *baseline_sum*."""
    return (forecast_sum - baseline_sum) / baseline_sum


def forecast_time_series(file, product_name, wb, ozon, model_choice):
    """Gradio callback: load a series, forecast it and build the outputs.

    Args:
        file: Uploaded CSV, or ``None`` to query the database instead.
        product_name: SQL predicate inserted after ``where`` in the database
            query (only used when *file* is ``None``).
        wb: Include the 'wildberries' marketplace in the database query.
        ozon: Include the 'ozon' marketplace in the database query.
        model_choice: ``"Prophet"`` or ``"Chronos"``.

    Returns:
        ``(figure, yoy_text, csv_path)`` for the plot, text and file outputs.

    Raises:
        gr.Error: if the filters are unusable or the query returns no rows.
        ValueError: if *model_choice* is not a known model.
    """
    if file is None:
        marketplaces = []
        if wb:
            marketplaces.append('wildberries')
        if ozon:
            marketplaces.append('ozon')
        if not marketplaces:
            raise gr.Error("Select at least one marketplace")
        if not product_name or not product_name.strip():
            raise gr.Error("Product name filter must not be empty")
        mp_filter = "', '".join(marketplaces)

        # SECURITY NOTE: `product_name` is spliced into the SQL verbatim
        # because the UI deliberately accepts a whole predicate. This is only
        # acceptable with a read-only database user and trusted operators;
        # it cannot be parameterised without changing the UI contract.
        query = f"""
        select
            cast(start_date as date) as ds,
            1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
        from datamart_all_1
        join week_data
        using (id_week)
        where {product_name}
        and mp in ('{mp_filter}')
        group by ds
        order by ds
        """
        print(query)
        data = get_data_from_db(query)
        print(data)
        period_type = "Week"
        period_col = "ds"
        if data.empty:
            raise gr.Error("No data found in database. Please adjust filters")
        # Assign by label so the column really becomes datetime64 (positional
        # in-place assignment can leave it as `object`).
        data['ds'] = pd.to_datetime(data['ds'])
        data.set_index('ds', inplace=True)
    else:
        data, period_type, period_col = read_and_process_file(file)

    if period_type == "Month":
        year = 12
        n_periods = 24
        freq = "MS"
    else:
        year = 52
        n_periods = year * 2
        freq = "W"

    df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})

    if model_choice == "Prophet":
        forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
    elif model_choice == "Chronos":
        forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
    else:
        raise ValueError("Invalid model choice")

    # Create Plotly figure (common for both models)
    fig = create_plot(data, forecast)

    # Combine original data and forecast
    combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)

    # Save combined data
    combined_file = 'combined_data.csv'
    combined_df.to_csv(combined_file)

    return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file


def forecast_prophet(df, n_periods, freq, year):
    """Fit Prophet on *df* and forecast *n_periods* steps ahead.

    Args:
        df: DataFrame with ``ds`` (dates) and ``y`` (values) columns.
        n_periods: Number of future periods to predict.
        freq: Pandas frequency string for the future dates.
        year: Number of periods that make up one year.

    Returns:
        ``(forecast, yoy_change)``: Prophet's prediction frame (history plus
        future) and the relative change between the sum of the first forecast
        year and the sum of the last observed year.
    """
    model = Prophet()
    model.fit(df)
    future = model.make_future_dataframe(periods=n_periods, freq=freq)
    forecast = model.predict(future)

    sum_last_year_original = df['y'].iloc[-year:].sum()
    # Take the future part first, then its first year; unlike
    # `iloc[-n:-n + year]` this is still correct when year == n_periods.
    sum_first_year_forecast = forecast['yhat'].iloc[-n_periods:].iloc[:year].sum()

    return forecast, _yoy_change(sum_first_year_forecast, sum_last_year_original)


@functools.lru_cache(maxsize=1)
def _load_chronos_pipeline():
    """Load the Chronos model once and reuse it across requests."""
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-mini",
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )


def forecast_chronos(df, n_periods, freq, year):
    """Forecast *n_periods* steps ahead with the Chronos pipeline.

    Args:
        df: DataFrame with ``ds`` (dates) and ``y`` (values) columns.
        n_periods: Number of future periods to predict.
        freq: Pandas frequency string for the future dates.
        year: Number of periods that make up one year.

    Returns:
        ``(forecast, yoy_change)``: a frame with ``ds``, ``yhat``,
        ``yhat_lower`` and ``yhat_upper`` (0.5 / 0.1 / 0.9 sample quantiles)
        for the future dates only, and the relative change between the sum of
        the first forecast year and the sum of the last observed year.

    Raises:
        ValueError: if ``y`` contains values that cannot be made numeric.
    """
    pipeline = _load_chronos_pipeline()

    # Check for non-numeric values
    if not pd.api.types.is_numeric_dtype(df['y']):
        non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
        if not non_numeric.empty:
            error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
            raise ValueError(error_message)

    try:
        y_values = df['y'].values.astype(np.float32)
    except ValueError as e:
        raise ValueError(f"Unable to convert 'y' column to float32: {str(e)}") from e

    chronos_forecast = pipeline.predict(
        context=torch.tensor(y_values),
        prediction_length=n_periods,
        num_samples=20,
        limit_prediction_length=False
    )

    # Future dates strictly after the last observation. When the last
    # observation is not on the frequency anchor (e.g. Monday data with
    # freq="W"), date_range already starts after it, so blindly dropping the
    # first element would skip a period.
    last_date = pd.Timestamp(df['ds'].iloc[-1])
    candidates = pd.date_range(start=last_date, periods=n_periods + 1, freq=freq)
    forecast_index = candidates[candidates > last_date][:n_periods]

    low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    forecast = pd.DataFrame({
        'ds': forecast_index,
        'yhat': median,
        'yhat_lower': low,
        'yhat_upper': high
    })

    sum_last_year_original = df['y'].iloc[-year:].sum()
    sum_first_year_forecast = median[:year].sum()

    return forecast, _yoy_change(sum_first_year_forecast, sum_last_year_original)


def create_plot(data, forecast):
    """Build the Plotly figure of observed values, forecast and its band.

    Args:
        data: Date-indexed DataFrame; its first column is plotted as observed.
        forecast: Frame with ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``.

    Returns:
        The ``plotly.graph_objs.Figure``.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'))
    fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')))
    # The upper trace fills down to the previously added (lower) trace.
    fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_lower'], fill=None, mode='lines', line=dict(color='pink'), name='Lower CI'))
    fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_upper'], fill='tonexty', mode='lines', line=dict(color='pink'), name='Upper CI'))

    fig.update_layout(
        title='Observed Time Series and Forecast with Confidence Intervals',
        xaxis_title='Date',
        yaxis_title='Values',
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        hovermode='x unified'
    )
    return fig


# Create Gradio interface using Blocks
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
    gr.Markdown("# Time Series Forecasting")
    gr.Markdown("Upload a CSV file with a time series to forecast the next 2 years and see the YoY % change. Download the combined original and forecast data.")

    with gr.Row():
        file_input = gr.File(label="Upload Time Series CSV")

    with gr.Row():
        wb_checkbox = gr.Checkbox(label="Wildberries", value=True)
        ozon_checkbox = gr.Checkbox(label="Ozon", value=True)

    with gr.Row():
        product_name_input = gr.Textbox(label="Product Name Filter", value="name like '%пуховик%'")

    with gr.Row():
        model_choice = gr.Radio(["Prophet", "Chronos"], label="Choose Model", value="Prophet")

    with gr.Row():
        compute_button = gr.Button("Compute")

    with gr.Row():
        plot_output = gr.Plot(label="Time Series + Forecast Chart")

    with gr.Row():
        yoy_output = gr.Text(label="YoY % Change")

    with gr.Row():
        csv_output = gr.File(label="Download Combined Data CSV")

    compute_button.click(
        forecast_time_series,
        inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
        outputs=[plot_output, yoy_output, csv_output]
    )

# Launch the interface
interface.launch(debug=True)