oscarwang2 commited on
Commit
c5963a7
1 Parent(s): 0e46137

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -59
app.py CHANGED
@@ -1,39 +1,21 @@
1
- import gradio as gr
2
- import yfinance as yf
3
- import plotly.graph_objects as go
4
- from statsmodels.tsa.arima.model import ARIMA
5
  import pandas as pd
6
  import logging
 
 
 
 
7
 
8
- # Setup logging
9
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
 
11
- def fetch_eth_price(period):
12
- eth = yf.Ticker("ETH-USD")
13
- if period == '1d':
14
- data = eth.history(period="1d", interval="1m")
15
- predict_steps = 60 # Next 60 minutes
16
- freq = 'min' # Minute frequency
17
- elif period == '5d':
18
- data = eth.history(period="5d", interval="15m")
19
- predict_steps = 96 # Next 24 hours
20
- freq = '15min' # 15 minutes frequency
21
- elif period == '1wk':
22
- data = eth.history(period="1wk", interval="30m")
23
- predict_steps = 336 # Next 7 days
24
- freq = '30min' # 30 minutes frequency
25
- elif period == '1mo':
26
- data = eth.history(period="1mo", interval="1h")
27
- predict_steps = 720 # Next 30 days
28
- freq = 'H' # Hourly frequency
29
- else:
30
- logging.error("Invalid period specified.")
31
- return None, None, None
32
-
33
- data.index = pd.DatetimeIndex(data.index)
34
- data = data.asfreq(freq) # Ensure the data has a consistent frequency
35
  logging.info(f"Fetched {len(data)} data points for the period {period}.")
36
- return data, predict_steps, freq
37
 
38
  def make_predictions(data, predict_steps, freq):
39
  if data is None or data.empty:
@@ -41,6 +23,12 @@ def make_predictions(data, predict_steps, freq):
41
  return None
42
 
43
  logging.info(f"Starting model training with {len(data)} data points...")
 
 
 
 
 
 
44
  try:
45
  model = ARIMA(data['Close'], order=(5, 1, 0))
46
  model_fit = model.fit()
@@ -48,21 +36,20 @@ def make_predictions(data, predict_steps, freq):
48
  except Exception as e:
49
  logging.error(f"Model training failed: {e}")
50
  return None
51
-
52
  logging.info("Model training completed.")
53
-
54
  logging.info("Generating predictions...")
55
  try:
56
  forecast = model_fit.forecast(steps=predict_steps)
 
 
 
57
  except Exception as e:
58
  logging.error(f"Prediction generation failed: {e}")
59
  return None
60
 
61
- if forecast.isnull().any():
62
- logging.error("Generated predictions contain NaN values.")
63
- return None
64
-
65
- future_dates = pd.date_range(start=data.index[-1], periods=predict_steps+1, freq=freq, inclusive='right')
66
  forecast_df = pd.DataFrame(forecast, index=future_dates[1:], columns=['Prediction'])
67
 
68
  logging.info(f"Forecast Data:\n{forecast_df.head()}")
@@ -70,34 +57,38 @@ def make_predictions(data, predict_steps, freq):
70
 
71
  return forecast_df
72
 
73
- def plot_eth(period):
74
- data, predict_steps, freq = fetch_eth_price(period)
75
- if data is None or predict_steps is None or freq is None:
76
- logging.error("Failed to fetch data or set up prediction parameters.")
77
- return None
78
-
79
  forecast_df = make_predictions(data, predict_steps, freq)
80
- if forecast_df is None or forecast_df.empty:
81
  logging.error("Failed to generate predictions.")
82
  return None
83
 
84
- fig = go.Figure()
85
- fig.add_trace(go.Scatter(x=data.index, y=data['Close'], mode='lines', name='ETH Price'))
86
- fig.add_trace(go.Scatter(x=forecast_df.index, y=forecast_df['Prediction'], mode='lines', name='Prediction', line=dict(dash='dash')))
87
- fig.update_layout(title=f"ETH Price and Predictions ({period})", xaxis_title="Date", yaxis_title="Price (USD)")
 
 
 
 
 
88
 
 
 
 
89
  logging.info("Plotting completed.")
90
- return fig
 
91
 
92
  def refresh_predictions(period):
93
- return plot_eth(period)
 
 
 
94
 
95
- with gr.Blocks() as iface:
96
- period = gr.Radio(["1d", "5d", "1wk", "1mo"], label="Select Period")
97
- plot = gr.Plot()
98
- refresh_button = gr.Button("Refresh Predictions and Prices")
99
-
100
- period.change(fn=plot_eth, inputs=period, outputs=plot)
101
- refresh_button.click(fn=refresh_predictions, inputs=period, outputs=plot)
102
-
103
  iface.launch()
 
1
+ import numpy as np
 
 
 
2
  import pandas as pd
3
  import logging
4
+ import matplotlib.pyplot as plt
5
+ from statsmodels.tsa.arima.model import ARIMA
6
+ import yfinance as yf
7
+ import gradio as gr
8
 
9
+ logging.basicConfig(level=logging.INFO)
 
10
 
11
+ def fetch_data(period='1d'):
12
+ logging.info(f"Fetching data for the period {period}...")
13
+ data = yf.download(tickers='ETH-USD', period=period, interval='1m')
14
+ if data.empty:
15
+ logging.error("No data fetched. Check the period or ticker symbol.")
16
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  logging.info(f"Fetched {len(data)} data points for the period {period}.")
18
+ return data
19
 
20
  def make_predictions(data, predict_steps, freq):
21
  if data is None or data.empty:
 
23
  return None
24
 
25
  logging.info(f"Starting model training with {len(data)} data points...")
26
+
27
+ # Check for NaN values in the data
28
+ if data['Close'].isna().any():
29
+ logging.error("Data contains NaN values. Please clean the data before model training.")
30
+ return None
31
+
32
  try:
33
  model = ARIMA(data['Close'], order=(5, 1, 0))
34
  model_fit = model.fit()
 
36
  except Exception as e:
37
  logging.error(f"Model training failed: {e}")
38
  return None
39
+
40
  logging.info("Model training completed.")
41
+
42
  logging.info("Generating predictions...")
43
  try:
44
  forecast = model_fit.forecast(steps=predict_steps)
45
+ if np.isnan(forecast).any():
46
+ logging.error("Generated predictions contain NaN values. Model might be improperly configured.")
47
+ return None
48
  except Exception as e:
49
  logging.error(f"Prediction generation failed: {e}")
50
  return None
51
 
52
+ future_dates = pd.date_range(start=data.index[-1], periods=predict_steps + 1, freq=freq, inclusive='right')
 
 
 
 
53
  forecast_df = pd.DataFrame(forecast, index=future_dates[1:], columns=['Prediction'])
54
 
55
  logging.info(f"Forecast Data:\n{forecast_df.head()}")
 
57
 
58
  return forecast_df
59
 
60
+ def plot_eth(period='1d'):
61
+ data = fetch_data(period)
62
+ predict_steps = 5 # Modify as needed
63
+ freq = 'T' # 'T' stands for minutes
64
+
 
65
  forecast_df = make_predictions(data, predict_steps, freq)
66
+ if forecast_df is None:
67
  logging.error("Failed to generate predictions.")
68
  return None
69
 
70
+ plt.figure(figsize=(10, 5))
71
+ plt.plot(data['Close'], label='Actual ETH Price')
72
+ plt.plot(forecast_df['Prediction'], label='Forecasted ETH Price', linestyle='dotted', color='orange')
73
+ plt.title('ETH Price Prediction')
74
+ plt.xlabel('Time')
75
+ plt.ylabel('Price (USD)')
76
+ plt.legend()
77
+ plt.grid(True)
78
+ plt.tight_layout()
79
 
80
+ # Save the plot to a file
81
+ plot_filename = '/home/user/app/eth_price_prediction.png'
82
+ plt.savefig(plot_filename)
83
  logging.info("Plotting completed.")
84
+
85
+ return plot_filename
86
 
87
  def refresh_predictions(period):
88
+ plot_filename = plot_eth(period)
89
+ if plot_filename is None:
90
+ return "Error in generating plot."
91
+ return plot_filename
92
 
93
+ iface = gr.Interface(fn=refresh_predictions, inputs="text", outputs="image", live=True)
 
 
 
 
 
 
 
94
  iface.launch()