본문 바로가기
파이썬

파이썬 - ARIMA predict 함수 오류 기록

by Tiabet 2023. 2. 8.

https://tiabet0929.tistory.com/10

 

이 글을 포스팅하면서 발생한 오류를 해결하는 데에 아주 애를 먹었다. 원인을 알아내고 해결하는 데에 꼬박 2일이 걸렸고, stackoverflow 같은 곳에서도 명쾌한 해답을 얻기가 어려웠었기 때문에 따로 포스팅하고자 한다.

우선 발생한 오류는  The start argument could not be matchted to a location related to the index of the data 이다.

발생한 오류

나는 우선 처음에 ACF와 PACF를 확인하고자 했고, 그래서 늘 하던대로 statsmodels 패키지의 plot_acf 함수를 사용하는 코드를 짰다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
 
# Load the data into a Pandas DataFrame
data = sm.datasets.co2.load_pandas()
df = data.data
 
# Convert the data to a time series format
ts = df['co2'].resample('M').mean()
 
print(ts)
 
# Plot the time series data
plt.plot(ts)
plt.xlabel("Year")
plt.ylabel("CO2 Concentration (ppm)")
plt.title("Monthly Mean CO2 Concentration")
plt.show()
 
# Plot the ACF and PACF of the time series data
sm.graphics.tsa.plot_acf(ts)
plt.show()
sm.graphics.tsa.plot_pacf(ts)
plt.show()
 
cs

이런 코드였는데, 결과가 이렇게 나왔다.

내 경험상 ACF 와 PACF 가 이렇게 나오는 이유는 중간에 Null 값이 존재하기 때문이어서, 나는 ts.dropna() 함수를 써서 null 값을 제거해주고 알맞은 그래프들을 얻을 수 있었다. print(ts)를 해봐도 NaN 값이 존재했고 이를 dropna 를 통해 제거했음을 알 수 있었다.

왼쪽이 dropna를 하기 전, 오른쪽이 하고 난 후

이제 SARIMAX 모델 예측을 하기 위해 역시 아래와 같은 코드를 실행했다. 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.tsa.statespace.sarimax import SARIMAX
 
 
# Split the data into training and testing sets
train = ts[:-12]
test = ts[-12:]
 
# Fit the SARIMA model to the training data
model = SARIMAX(train, order=(1,1,1), seasonal_order=(1,1,0,12))
results = model.fit()
 
# Use the SARIMA model to make predictions for the testing data
predictions = results.predict(start=test.index[0], end=test.index[-1], dynamic=False)
 
# Plot the actual and predicted values for the testing data
plt.plot(train, label="Training Data")
plt.plot(test, label="Actual Values",linewidth='3')
plt.plot(predictions, label="Predicted Values")
plt.xlabel("Year")
plt.ylabel("CO2 Concentration (ppm)")
plt.title("SARIMA Model")
plt.legend()
plt.show()
 
cs

그리고 실행을 했더니,

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
C:\Anaconda\lib\site-packages\pandas\_libs\index.pyx in pandas._libs.index.DatetimeEngine.get_loc()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

KeyError: 980899200000000000

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
C:\Anaconda\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   3628             try:
-> 3629                 return self._engine.get_loc(casted_key)
   3630             except KeyError as err:

C:\Anaconda\lib\site-packages\pandas\_libs\index.pyx in pandas._libs.index.DatetimeEngine.get_loc()

C:\Anaconda\lib\site-packages\pandas\_libs\index.pyx in pandas._libs.index.DatetimeEngine.get_loc()

KeyError: Timestamp('2001-01-31 00:00:00')

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
C:\Anaconda\lib\site-packages\pandas\core\indexes\datetimes.py in get_loc(self, key, method, tolerance)
    695         try:
--> 696             return Index.get_loc(self, key, method, tolerance)
    697         except KeyError as err:

C:\Anaconda\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   3630             except KeyError as err:
-> 3631                 raise KeyError(key) from err
   3632             except TypeError:

KeyError: Timestamp('2001-01-31 00:00:00')

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
C:\Anaconda\lib\site-packages\statsmodels\tsa\base\tsa_model.py in get_index_label_loc(key, index, row_labels)
    246             if not isinstance(key, (int, np.integer)):
--> 247                 loc = row_labels.get_loc(key)
    248             else:

C:\Anaconda\lib\site-packages\pandas\core\indexes\datetimes.py in get_loc(self, key, method, tolerance)
    697         except KeyError as err:
--> 698             raise KeyError(orig_key) from err
    699 

KeyError: Timestamp('2001-01-31 00:00:00')

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
C:\Anaconda\lib\site-packages\statsmodels\tsa\base\tsa_model.py in get_prediction_index(start, end, nobs, base_index, index, silent, index_none, index_generated, data)
    355     try:
--> 356         start, _, start_oos = get_index_label_loc(
    357             start, base_index, data.row_labels

C:\Anaconda\lib\site-packages\statsmodels\tsa\base\tsa_model.py in get_index_label_loc(key, index, row_labels)
    278         except:
--> 279             raise e
    280     return loc, index, index_was_expanded

C:\Anaconda\lib\site-packages\statsmodels\tsa\base\tsa_model.py in get_index_label_loc(key, index, row_labels)
    242     try:
--> 243         loc, index, index_was_expanded = get_index_loc(key, index)
    244     except KeyError as e:

C:\Anaconda\lib\site-packages\statsmodels\tsa\base\tsa_model.py in get_index_loc(key, index)
    192         except (IndexError, ValueError) as e:
--> 193             raise KeyError(str(e))
    194         loc = key

KeyError: 'only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices'

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_16896\2562130860.py in <module>
     15 
     16 # Use the SARIMA model to make predictions for the testing data
---> 17 predictions = results.predict(start=test.index[0], end=test.index[-1], dynamic=False)
     18 
     19 # Plot the actual and predicted values for the testing data

C:\Anaconda\lib\site-packages\statsmodels\base\wrapper.py in wrapper(self, *args, **kwargs)
    111             obj = data.wrap_output(func(results, *args, **kwargs), how[0], how[1:])
    112         elif how:
--> 113             obj = data.wrap_output(func(results, *args, **kwargs), how)
    114         return obj
    115 

C:\Anaconda\lib\site-packages\statsmodels\tsa\statespace\mlemodel.py in predict(self, start, end, dynamic, **kwargs)
   3401         """
   3402         # Perform the prediction
-> 3403         prediction_results = self.get_prediction(start, end, dynamic, **kwargs)
   3404         return prediction_results.predicted_mean
   3405 

C:\Anaconda\lib\site-packages\statsmodels\tsa\statespace\mlemodel.py in get_prediction(self, start, end, dynamic, index, exog, extend_model, extend_kwargs, **kwargs)
   3285         # Handle start, end, dynamic
   3286         start, end, out_of_sample, prediction_index = (
-> 3287             self.model._get_prediction_index(start, end, index))
   3288 
   3289         # Handle `dynamic`

C:\Anaconda\lib\site-packages\statsmodels\tsa\base\tsa_model.py in _get_prediction_index(self, start, end, index, silent)
    832         """
    833         nobs = len(self.endog)
--> 834         return get_prediction_index(
    835             start,
    836             end,

C:\Anaconda\lib\site-packages\statsmodels\tsa\base\tsa_model.py in get_prediction_index(start, end, nobs, base_index, index, silent, index_none, index_generated, data)
    358         )
    359     except KeyError:
--> 360         raise KeyError(
    361             "The `start` argument could not be matched to a"
    362             " location related to the index of the data."

KeyError: 'The `start` argument could not be matched to a location related to the index of the data.'

위와같은 어마어마하게 긴 에러코드가 발생했다.

predictions = results.predict(start=test.index[0], end=test.index[-1], dynamic=False)

KeyError: 'The `start` argument could not be matched to a location related to the index of the data.'

 

아무튼 중요한 부분은 위 코드에서 에러가 일어난다는 점인데, 도저히 이 말을 이해할 수가 없었다. stackoverflow 도 뒤져보고, 에러코드를 자세히 읽어본 결과 아무튼 index가 뭔가 잘못되었다는 점이었는데, 이를 해결하기 위해 많은 것을 시도해보았다.

그렇게 몇시간을 날렸지만, 도저히 원인을 모르겠었는데 단서가 하나 있었다. 

 

predictions = results.predict(start=train.index[0], end=train.index[-1], dynamic=False)

이 predictions을 test 대신 train 으로 고치면 그래프가 이상하게나마 나오긴 한다는 것이다.

 

이 때문에 더더욱 갈피를 못 잡고 길을 헤메었다. Index가 date나 time으로 인식을 못 해서 발생하는 오류라고 생각했기 때문이다. 그래서 chatGPT 에도 그렇게 물어봤지만 원하는 답을 얻을 수 없었다. (자꾸 시간 형식이 달라서 그럴 거라고 함) 그렇게 한참 헤매고 있다가 다시 stackoverflow를 뒤져서 해결의 실마리를 찾을 수 있었다.

https://stackoverflow.com/questions/58580633/the-start-argument-could-not-be-matched-to-a-location-related-to-the-index-of

 

The `start` argument could not be matched to a location related to the index of the data

I don't know why my 'start' pred won't work. I added some edits to pd.to_datetime but they didn't work. This is my code: pred = results.get_prediction(start=pd.to_datetime('2018-06-01'), dynamic=Fa...

stackoverflow.com

 

해결책

StackOverFlow에서 가장 많은 추천을 받은 답변.

처음엔 위 답변 때문에 내 코드가 오류를 일으키는 이유가 일간이 아니라 월간 데이터라 skip 되었다고 인식된건가? 하는 생각이 들었다. 방향을 잘못 잡은 것이다. 그러다가 내가 위에서 dropna()를 했다는 사실이 퍼뜩 떠올라서, 아래 코드처럼 데이터를 처음부터 로드해서 시도해보았더니 결과가 잘 나왔다!

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.tsa.statespace.sarimax import SARIMAX
 
# Load the data into a Pandas DataFrame
data = sm.datasets.co2.load_pandas()
df = data.data
df=df.dropna(subset=['co2'])
 
# Convert the data to a time series format
ts = df['co2'].resample('M').mean()
 
# Split the data into training and testing sets
train = ts[:-12]
test = ts[-12:]
 
# Fit the SARIMA model to the training data
model = SARIMAX(train, order=(1,1,1), seasonal_order=(1,1,0,12))
results = model.fit()
 
# Use the SARIMA model to make predictions for the testing data
predictions = results.predict(start=test.index[0], end=test.index[-1], dynamic=False)
 
# Plot the actual and predicted values for the testing data
plt.plot(train, label="Training Data")
plt.plot(test, label="Actual Values",linewidth='3')
plt.plot(predictions, label="Predicted Values")
plt.xlabel("Year")
plt.ylabel("CO2 Concentration (ppm)")
plt.title("SARIMA Model")
plt.legend()
plt.show()
 
cs

그제서야 나의 공부친구 chatGPT 에게 물어봤더니 이런 대답을 했다.

chatGPT와 나의 대화
StackOverFlow에서 가장 많은 추천을 받은 답변.

결론 : stackoverflow는 항상 옳다. frequency가 만족해야 predict 함수를 사용할 수 있는 것이었다.. 이를 나처럼 데이터를 새로 로드하고 사용하지 않으려면 다음처럼 명시적으로 frequency를 표시해줘야 한다고 한다.

 

1
2
3
4
5
6
7
8
9
10
import pandas as pd
 
# Load your data and set the index as the first column
df = pd.read_csv('data.csv', index_col=0)
df = df.dropna()
df.index = pd.to_datetime(df.index)
 
# Set the frequency of the data, if known
ts = pd.Series(df['sales'].values, index=df.index, freq='M')
 
cs

 

이렇게 오류를 해결할 수 있었다. 고치는 데에만 몇 시간이 걸렸고, 저번 주 스터디 할 때 작업했을 때도 위와 같은 오류를 발견했다가 해결하지 못하고 건너뛴 적이 있어서 해결한 데에 뿌듯함이 들었다. 

 

앞으로도 이렇게 해결하는 데에 오래 걸린 오류가 있으면 블로그에 기록해두고자 한다. 끝!