You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

timefeatures.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from typing import List
  2. import pandas as pd
  3. from pandas.tseries import offsets
  4. from pandas.tseries.frequencies import to_offset
  5. import numpy as np
  6. import mindspore.numpy as mnp
  7. import mindspore.dataset as ds
  8. class TimeFeature:
  9. def __init__(self):
  10. pass
  11. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  12. pass
  13. def __repr__(self):
  14. return self.__class__.__name__ + "()"
  15. class SecondOfMinute(TimeFeature):
  16. """Minute of hour encoded as value between [-0.5, 0.5]"""
  17. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  18. return index.second / 59.0 - 0.5
  19. class MinuteOfHour(TimeFeature):
  20. """Minute of hour encoded as value between [-0.5, 0.5]"""
  21. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  22. return index.minute / 59.0 - 0.5
  23. class HourOfDay(TimeFeature):
  24. """Hour of day encoded as value between [-0.5, 0.5]"""
  25. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  26. return index.hour / 23.0 - 0.5
  27. class DayOfWeek(TimeFeature):
  28. """Hour of day encoded as value between [-0.5, 0.5]"""
  29. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  30. return index.dayofweek / 6.0 - 0.5
  31. class DayOfMonth(TimeFeature):
  32. """Day of month encoded as value between [-0.5, 0.5]"""
  33. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  34. return (index.day - 1) / 30.0 - 0.5
  35. class DayOfYear(TimeFeature):
  36. """Day of year encoded as value between [-0.5, 0.5]"""
  37. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  38. return (index.dayofyear - 1) / 365.0 - 0.5
  39. class MonthOfYear(TimeFeature):
  40. """Month of year encoded as value between [-0.5, 0.5]"""
  41. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  42. return (index.month - 1) / 11.0 - 0.5
  43. class WeekOfYear(TimeFeature):
  44. """Week of year encoded as value between [-0.5, 0.5]"""
  45. def __call__(self, index: pd.DatetimeIndex) -> mnp.ndarray:
  46. return (index.week - 1) / 52.0 - 0.5
  47. def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
  48. """
  49. Returns a list of time features that will be appropriate for the given frequency string.
  50. Parameters
  51. ----------
  52. freq_str
  53. Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
  54. """
  55. features_by_offsets = {
  56. offsets.YearEnd: [],
  57. offsets.QuarterEnd: [MonthOfYear],
  58. offsets.MonthEnd: [MonthOfYear],
  59. offsets.Week: [DayOfMonth, WeekOfYear],
  60. offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
  61. offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
  62. offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
  63. offsets.Minute: [
  64. MinuteOfHour,
  65. HourOfDay,
  66. DayOfWeek,
  67. DayOfMonth,
  68. DayOfYear,
  69. ],
  70. offsets.Second: [
  71. SecondOfMinute,
  72. MinuteOfHour,
  73. HourOfDay,
  74. DayOfWeek,
  75. DayOfMonth,
  76. DayOfYear,
  77. ],
  78. }
  79. offset = to_offset(freq_str)
  80. for offset_type, feature_classes in features_by_offsets.items():
  81. if isinstance(offset, offset_type):
  82. return [cls() for cls in feature_classes]
  83. supported_freq_msg = f"""
  84. Unsupported frequency {freq_str}
  85. The following frequencies are supported:
  86. Y - yearly
  87. alias: A
  88. M - monthly
  89. W - weekly
  90. D - daily
  91. B - business days
  92. H - hourly
  93. T - minutely
  94. alias: min
  95. S - secondly
  96. """
  97. raise RuntimeError(supported_freq_msg)
  98. def time_features(dates, timeenc=1, freq='h'):
  99. """
  100. > `time_features` takes in a `dates` dataframe with a 'dates' column and extracts the date down to `freq` where freq can be any of the following if `timeenc` is 0:
  101. > * m - [month]
  102. > * w - [month]
  103. > * d - [month, day, weekday]
  104. > * b - [month, day, weekday]
  105. > * h - [month, day, weekday, hour]
  106. > * t - [month, day, weekday, hour, *minute]
  107. >
  108. > If `timeenc` is 1, a similar, but different list of `freq` values are supported (all encoded between [-0.5 and 0.5]):
  109. > * Q - [month]
  110. > * M - [month]
  111. > * W - [Day of month, week of year]
  112. > * D - [Day of week, day of month, day of year]
  113. > * B - [Day of week, day of month, day of year]
  114. > * H - [Hour of day, day of week, day of month, day of year]
  115. > * T - [Minute of hour*, hour of day, day of week, day of month, day of year]
  116. > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year]
  117. *minute returns a number from 0-3 corresponding to the 15 minute period it falls into.
  118. """
  119. if timeenc==0:
  120. dates['month'] = dates.date.apply(lambda row:row.month,1)
  121. dates['day'] = dates.date.apply(lambda row:row.day,1)
  122. dates['weekday'] = dates.date.apply(lambda row:row.weekday(),1)
  123. dates['hour'] = dates.date.apply(lambda row:row.hour,1)
  124. dates['minute'] = dates.date.apply(lambda row:row.minute,1)
  125. dates['minute'] = dates.minute.map(lambda x:x//15)
  126. freq_map = {
  127. 'y':[],'m':['month'],'w':['month'],'d':['month','day','weekday'],
  128. 'b':['month','day','weekday'],'h':['month','day','weekday','hour'],
  129. 't':['month','day','weekday','hour','minute'],
  130. }
  131. return dates[freq_map[freq.lower()]].values
  132. if timeenc==1:
  133. dates = pd.to_datetime(dates.date.values)
  134. return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0)

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN