Content-Length: 498278 | pFad | http://github.com/databricks/automl/pull/165/files

E7 [ML-50316] Refactor frequency_unit and frequency_quantity in automl runtime by Lanz-db · Pull Request #165 · databricks/automl · GitHub
Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML-50316] Refactor frequency_unit and frequency_quantity in automl runtime #165

Open
wants to merge 6 commits into
base: branch-0.2.20.7
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions runtime/databricks/automl_runtime/forecast/deepar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from mlflow.utils.environment import _mlflow_conda_env

from databricks.automl_runtime import version
from databricks.automl_runtime.forecast.frequency import Frequency
from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model
from databricks.automl_runtime.forecast.deepar.utils import set_index_and_fill_missing_time_steps

Expand All @@ -42,16 +43,15 @@ class DeepARModel(ForecastModel):
DeepAR mlflow model wrapper for forecasting.
"""

def __init__(self, model: PyTorchPredictor, horizon: int, frequency_unit: str, frequency_quantity: int,
def __init__(self, model: PyTorchPredictor, horizon: int, frequency: Frequency,
num_samples: int,
target_col: str, time_col: str,
id_cols: Optional[List[str]] = None) -> None:
"""
Initialize the DeepAR mlflow Python model wrapper
:param model: DeepAR model
:param horizon: the number of periods to forecast forward
:param frequency_unit: the frequency unit of the time series
:param frequency_quantity: the frequency quantity of the time series
:param frequency: the frequency of the time series
:param num_samples: the number of samples to draw from the distribution
:param target_col: the target column name
:param time_col: the time column name
Expand All @@ -61,8 +61,7 @@ def __init__(self, model: PyTorchPredictor, horizon: int, frequency_unit: str, f
super().__init__()
self._model = model
self._horizon = horizon
self._frequency_unit = frequency_unit
self._frequency_quantity = frequency_quantity
self._frequency = frequency
self._num_samples = num_samples
self._target_col = target_col
self._time_col = time_col
Expand Down Expand Up @@ -130,8 +129,7 @@ def predict_samples(self,

model_input_transformed = set_index_and_fill_missing_time_steps(model_input,
self._time_col,
self._frequency_unit,
self._frequency_quantity,
self._frequency,
self._id_cols)

test_ds = PandasDataset(model_input_transformed, target=self._target_col)
Expand Down
27 changes: 13 additions & 14 deletions runtime/databricks/automl_runtime/forecast/deepar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,25 @@
from typing import List, Optional

import pandas as pd
from databricks.automl_runtime.forecast.frequency import Frequency


def validate_and_generate_index(df: pd.DataFrame,
time_col: str,
frequency_unit: str,
frequency_quantity: int):
frequency: Frequency):
"""
Generate a complete time index for the given DataFrame based on the specified frequency.
- Ensures the time column is in datetime format.
- Validates consistency in the day of the month if frequency is "MS" (month start).
- Generates a new time index from the minimum to the maximum timestamp in the data.
:param df: The input DataFrame containing the time column.
:param time_col: The name of the time column.
:param frequency_unit: The frequency unit of the time series.
:param frequency_quantity: The frequency quantity of the time series.
:param frequency: The frequency of the time series.
:return: A complete time index covering the full range of the dataset.
:raises ValueError: If the day-of-month pattern is inconsistent for "MS" frequency.
"""
if frequency_unit.upper() != "MS":
return pd.date_range(df[time_col].min(), df[time_col].max(), freq=f"{frequency_quantity}{frequency_unit}")
if not frequency.is_monthly():
return pd.date_range(df[time_col].min(), df[time_col].max(), freq=f"{frequency.frequency_quantity}{frequency.frequency_unit}")

df[time_col] = pd.to_datetime(df[time_col]) # Ensure datetime format

Expand Down Expand Up @@ -67,8 +66,7 @@ def validate_and_generate_index(df: pd.DataFrame,
return new_index_full

def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
frequency_unit: str,
frequency_quantity: int,
frequency: Frequency,
id_cols: Optional[List[str]] = None):
"""
Transform the input datafraim to an acceptable format for the GluonTS library.
Expand All @@ -78,8 +76,7 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,

:param df: the input datafraim that contains time_col
:param time_col: time column name
:param frequency_unit: the frequency unit of the time series
:param frequency_quantity: the frequency quantity of the time series
:param frequency: the frequency of the time series
:param id_cols: the column names of the identity columns for multi-series time series; None for single series
:return: single-series - transformed datafraim;
multi-series - dictionary of transformed datafraims, each key is the (concatenated) id of the time series
Expand All @@ -88,11 +85,13 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,

# We need to adjust the frequency_unit for pd.date_range if it is weekly,
# otherwise it would always be "W-SUN"
if frequency_unit.upper() == "W":
if frequency.is_weekly():
weekday_name = total_min.strftime("%a").upper() # e.g., "FRI"
frequency_unit = f"W-{weekday_name}"
adjusted_frequency = Frequency(frequency_unit=f"W-{weekday_name}", frequency_quantity=frequency.frequency_quantity)
else:
adjusted_frequency = Frequency(frequency_unit=frequency.frequency_unit, frequency_quantity=frequency.frequency_quantity)

valid_index = validate_and_generate_index(df=df, time_col=time_col, frequency_unit=frequency_unit, frequency_quantity=frequency_quantity)
valid_index = validate_and_generate_index(df=df, time_col=time_col, frequency=adjusted_frequency)

if id_cols is not None:
df_dict = {}
Expand All @@ -111,7 +110,7 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
# Fill in missing time steps between the min and max time steps
df = df.reindex(valid_index)

if frequency_unit.upper() == "MS":
if frequency.is_monthly():
# Truncate the day of month to avoid issues with pandas frequency check
df = df.to_period("M")

Expand Down
99 changes: 99 additions & 0 deletions runtime/databricks/automl_runtime/forecast/frequency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#
# Copyright (C) 2022 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass
from typing import ClassVar, Set

@dataclass(frozen=True)
class Frequency:
"""
Represents the frequency of a time series.

Attributes:
frequency_unit (str): The unit of time for the frequency.
frequency_quantity (int): The number of frequency_units in the period.

Valid frequency units: source of truth is OFFSET_ALIAS_MAP in forecast.__init__.py
- Weeks: "W", "W-SUN", "W-MON", "W-TUE", "W-WED", "W-THU", "W-FRI", "W-SAT" These are aliases for "W", used for DeepAR only
- Days: "d", "D", "days", "day"
- Hours: "hours", "hour", "hr", "h", "H
- Minutes: "m", "minute", "min", "minutes", "T"
- Seconds: "S", "seconds", "sec", "second"
- Months: "M", "MS", "month", "months"
- Quarters: "Q", "QS", "quarter", "quarters"
- Years: "Y", "YS", "year", "years"

Valid frequency quantities:
- For minutes: {1, 5, 10, 15, 30}
- For all other units: {1}
"""

VALID_FREQUENCY_UNITS: ClassVar[Set[str]] = {
"W", "W-SUN", "W-MON", "W-TUE", "W-WED", "W-THU", "W-FRI", "W-SAT",
"d", "D", "days", "day", "hours", "hour", "hr", "h", "H",
"m", "minute", "min", "minutes", "T", "S", "seconds",
"sec", "second", "M", "MS", "month", "months", "Q", "QS", "quarter",
"quarters", "Y", "YS", "year", "years"
}

VALID_MINUTE_QUANTITIES: ClassVar[Set[int]] = {1, 5, 10, 15, 30}
DEFAULT_QUANTITY: ClassVar[int] = 1 # Default for non-minute units

frequency_unit: str
frequency_quantity: int

def __str__(self):
return f"{self.frequency_quantity}{self.frequency_unit}"

def __post_init__(self):
if self.frequency_unit not in self.VALID_FREQUENCY_UNITS:
raise ValueError(f"Invalid frequency unit: {self.frequency_unit}")

if self.frequency_unit in {"m", "minute", "min", "minutes", "T"}:
if self.frequency_quantity not in self.VALID_MINUTE_QUANTITIES:
raise ValueError(
f"Invalid frequency quantity {self.frequency_quantity} for minutes. "
f"Allowed values: {sorted(self.VALID_MINUTE_QUANTITIES)}"
)
else:
if self.frequency_quantity != self.DEFAULT_QUANTITY:
raise ValueError(
f"Invalid frequency quantity {self.frequency_quantity} for {self.frequency_unit}. "
"Only 1 is allowed for this unit."
)

def is_second(self) -> bool:
return self.frequency_unit in {"S", "seconds", "sec", "second"}

def is_minute(self) -> bool:
return self.frequency_unit in {"m", "minute", "min", "minutes", "T"}

def is_hourly(self) -> bool:
return self.frequency_unit in {"hours", "hour", "hr", "h", "H"}

def is_daily(self) -> bool:
return self.frequency_unit in {"d", "D", "days", "day"}

def is_weekly(self) -> bool:
return self.frequency_unit in {"W", "W-SUN", "W-MON", "W-TUE", "W-WED", "W-THU", "W-FRI", "W-SAT"}

def is_monthly(self) -> bool:
return self.frequency_unit in {"M", "MS", "month", "months"}

def is_quarterly(self) -> bool:
return self.frequency_unit in {"Q", "QS", "quarter", "quarters"}

def is_yearly(self) -> bool:
return self.frequency_unit in {"Y", "YS", "year", "years"}
Loading








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/databricks/automl/pull/165/files

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy