You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

timefeatures.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
from typing import List

import mindspore.dataset as ds
import mindspore.numpy as mnp
import numpy as np
import pandas as pd
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset
  7. class TimeFeature:
  8. def __init__(self):
  9. pass
  10. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  11. pass
  12. def __repr__(self):
  13. return self.__class__.__name__ + "()"
  14. class SecondOfMinute(TimeFeature):
  15. """Minute of hour encoded as value between [-0.5, 0.5]"""
  16. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  17. return index.second / 59.0 - 0.5
  18. class MinuteOfHour(TimeFeature):
  19. """Minute of hour encoded as value between [-0.5, 0.5]"""
  20. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  21. return index.minute / 59.0 - 0.5
  22. class HourOfDay(TimeFeature):
  23. """Hour of day encoded as value between [-0.5, 0.5]"""
  24. def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
  25. return index.hour / 23.0 - 0.5
  26. class DayOfWeek(TimeFeature):
  27. """Hour of day encoded as value between [-0.5, 0.5]"""
  28. def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
  29. return index.dayofweek / 6.0 - 0.5
  30. class DayOfMonth(TimeFeature):
  31. """Day of month encoded as value between [-0.5, 0.5]"""
  32. def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
  33. return (index.day - 1) / 30.0 - 0.5
  34. class DayOfYear(TimeFeature):
  35. """Day of year encoded as value between [-0.5, 0.5]"""
  36. def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
  37. return (index.dayofyear - 1) / 365.0 - 0.5
  38. class MonthOfYear(TimeFeature):
  39. """Month of year encoded as value between [-0.5, 0.5]"""
  40. def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
  41. return (index.month - 1) / 11.0 - 0.5
  42. class WeekOfYear(TimeFeature):
  43. """Week of year encoded as value between [-0.5, 0.5]"""
  44. def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
  45. return (index.week - 1) / 52.0 - 0.5
  46. def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
  47. """
  48. Returns a list of time features that will be appropriate for the given frequency string.
  49. Parameters
  50. ----------
  51. freq_str
  52. Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
  53. """
  54. features_by_offsets = {
  55. offsets.YearEnd: [],
  56. offsets.QuarterEnd: [MonthOfYear],
  57. offsets.MonthEnd: [MonthOfYear],
  58. offsets.Week: [DayOfMonth, WeekOfYear],
  59. offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
  60. offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
  61. offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
  62. offsets.Minute: [
  63. MinuteOfHour,
  64. HourOfDay,
  65. DayOfWeek,
  66. DayOfMonth,
  67. DayOfYear,
  68. ],
  69. offsets.Second: [
  70. SecondOfMinute,
  71. MinuteOfHour,
  72. HourOfDay,
  73. DayOfWeek,
  74. DayOfMonth,
  75. DayOfYear,
  76. ],
  77. }
  78. offset = to_offset(freq_str)
  79. for offset_type, feature_classes in features_by_offsets.items():
  80. if isinstance(offset, offset_type):
  81. return [cls() for cls in feature_classes]
  82. supported_freq_msg = f"""
  83. Unsupported frequency {freq_str}
  84. The following frequencies are supported:
  85. Y - yearly
  86. alias: A
  87. M - monthly
  88. W - weekly
  89. D - daily
  90. B - business days
  91. H - hourly
  92. T - minutely
  93. alias: min
  94. S - secondly
  95. """
  96. raise RuntimeError(supported_freq_msg)
  97. def time_features(dates, timeenc=1, freq='h'):
  98. """
  99. > `time_features` takes in a `dates` dataframe with a 'dates' column and extracts the date down to `freq` where freq can be any of the following if `timeenc` is 0:
  100. > * m - [month]
  101. > * w - [month]
  102. > * d - [month, day, weekday]
  103. > * b - [month, day, weekday]
  104. > * h - [month, day, weekday, hour]
  105. > * t - [month, day, weekday, hour, *minute]
  106. >
  107. > If `timeenc` is 1, a similar, but different list of `freq` values are supported (all encoded between [-0.5 and 0.5]):
  108. > * Q - [month]
  109. > * M - [month]
  110. > * W - [Day of month, week of year]
  111. > * D - [Day of week, day of month, day of year]
  112. > * B - [Day of week, day of month, day of year]
  113. > * H - [Hour of day, day of week, day of month, day of year]
  114. > * T - [Minute of hour*, hour of day, day of week, day of month, day of year]
  115. > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year]
  116. *minute returns a number from 0-3 corresponding to the 15 minute period it falls into.
  117. """
  118. if timeenc==0:
  119. dates['month'] = dates.date.apply(lambda row:row.month,1)
  120. dates['day'] = dates.date.apply(lambda row:row.day,1)
  121. dates['weekday'] = dates.date.apply(lambda row:row.weekday(),1)
  122. dates['hour'] = dates.date.apply(lambda row:row.hour,1)
  123. dates['minute'] = dates.date.apply(lambda row:row.minute,1)
  124. dates['minute'] = dates.minute.map(lambda x:x//15)
  125. freq_map = {
  126. 'y':[],'m':['month'],'w':['month'],'d':['month','day','weekday'],
  127. 'b':['month','day','weekday'],'h':['month','day','weekday','hour'],
  128. 't':['month','day','weekday','hour','minute'],
  129. }
  130. return dates[freq_map[freq.lower()]].values
  131. if timeenc==1:
  132. dates = pd.to_datetime(dates.date.values)
  133. return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0)

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN