
predict

ies_pi_predict.api.predict

task_predict

```python
task_predict(
    df: DataFrame,
    output_column: str,
    predict_start: datetime,
    predict_end: datetime,
    output_folder: str,
) -> dict[str, Any]
```

task_predict is the task that calls the ies_pi_predict.predict function and saves the resulting dataframe to a CSV file. It must be defined in the module's global namespace (at top level) so that the Worker can pick it up correctly.
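A global-namespace requirement like this usually comes from how work is shipped to a worker process. The source does not say how the Worker loads tasks. Suppose, as an assumption, that it serializes them with Python's `pickle`, as `multiprocessing` and many task queues do. The sketch below then shows why the requirement holds: `pickle` stores a function as a reference to its module-level name, so only top-level functions can be serialized. `top_level_task`, `make_nested_task` and `is_picklable` are hypothetical names used for illustration.

```python
import pickle


def top_level_task(x: int) -> int:
    # Module-level function: pickle records it as "<module>.top_level_task".
    return x * 2


def make_nested_task():
    # A function defined inside another one has no importable module-level name.
    def nested_task(x: int) -> int:
        return x * 2

    return nested_task


def is_picklable(obj) -> bool:
    """Return True if pickle can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


print(is_picklable(top_level_task))      # True
print(is_picklable(make_nested_task()))  # False
print(is_picklable(lambda x: x * 2))     # False
```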

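Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| df | DataFrame | Same as in ies_pi_predict.predict | required |
| output_column | str | Same as in ies_pi_predict.predict | required |
| predict_start | datetime | Same as in ies_pi_predict.predict | required |
| predict_end | datetime | Same as in ies_pi_predict.predict | required |
| output_folder | str | The folder in which to save the CSV file | required |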
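Nothing in `task_predict` creates `output_folder`, so the folder presumably has to exist before the task runs. If the directory is missing, `DataFrame.to_csv` fails when it tries to open the file. Below is a minimal sketch of preparing the folder on the caller's side. `prepare_output_folder` is a hypothetical helper and not part of the package.

```python
import tempfile
from pathlib import Path


def prepare_output_folder(output_folder: str) -> Path:
    """Create output_folder (and any missing parents) if it does not exist yet."""
    folder = Path(output_folder)
    # exist_ok=True makes the call safe to repeat when the folder already exists.
    folder.mkdir(parents=True, exist_ok=True)
    return folder


with tempfile.TemporaryDirectory() as tmp:
    target = prepare_output_folder(str(Path(tmp, "predictions", "daily")))
    print(target.is_dir())  # True
```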

Returns:

| Name | Type | Description |
| --- | --- | --- |
| dict | dict[str, Any] | { "output_filename": str, "algorhythm": str, "score": float, "rmse": float } |
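The returned `output_filename` is a bare file name, not a path, so a caller has to join it with the `output_folder` it passed in. The sketch below shows that step. `PredictResult` is a hypothetical `TypedDict` that mirrors the documented keys, including the `algorhythm` spelling used by the API. `output_path` is also a hypothetical helper. The values in the usage lines are made up.

```python
from pathlib import Path
from typing import TypedDict


class PredictResult(TypedDict):
    """Shape of the dict returned by task_predict, as documented above."""
    output_filename: str
    algorhythm: str  # key spelled exactly as the API returns it
    score: float
    rmse: float


def output_path(output_folder: str, result: PredictResult) -> Path:
    """Return the full path of the CSV file written by task_predict."""
    return Path(output_folder, result["output_filename"])


# Illustrative values only, not real output of task_predict.
result: PredictResult = {
    "output_filename": "example.csv",
    "algorhythm": "SomeModel",
    "score": 0.9,
    "rmse": 1.5,
}
print(output_path("out", result))
```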

Source code in src/ies_pi_predict/api/predict.py
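```python
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import uuid4

import pandas as pd

# `learned_predict` and `datetime_format` are defined or imported elsewhere
# in this module; they are not shown in this excerpt.


def task_predict(df: pd.DataFrame, output_column: str,
                 predict_start: datetime, predict_end: datetime,
                 output_folder: str) -> dict[str, Any]:
    """
    task_predict is the task that calls the `ies_pi_predict.predict` function
    and saves the resulting dataframe to a CSV file.
    It must be defined in the module's global namespace so that the Worker
    can pick it up correctly.

    Args:
        df, output_column, predict_start, predict_end: same as `ies_pi_predict.predict`
        output_folder (str): The folder in which to save the CSV file

    Returns:
        dict: {
            "output_filename": str,
            "algorhythm": str,
            "score": float,
            "rmse": float,
        }
    """
    # Run the prediction: returns the predicted dataframe, the algorithm used,
    # its score and its RMSE.
    df, algorhythm, score, rmse = learned_predict(df, output_column,
                                                  predict_start, predict_end)

    # Build a unique file name of the form "<timestamp>-<uuid4>.csv".
    timestamp = datetime.now().strftime(datetime_format)
    run_id = uuid4()
    filename = f"{timestamp}-{run_id}.csv"
    filepath = Path(output_folder, filename)
    df.to_csv(filepath)

    # Only the file name is returned; callers join it with output_folder.
    return {
        "output_filename": filename,
        "algorhythm": str(algorhythm),
        "score": score,
        "rmse": rmse,
    }
```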
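The output file name combines a timestamp with a random UUID4, so two runs that start within the same timestamp resolution still produce distinct files. The source does not show the value of `datetime_format`. The sketch below assumes `"%Y%m%d%H%M%S"` purely for illustration. `make_output_filename` is a hypothetical helper that reproduces the naming logic.

```python
from datetime import datetime
from uuid import uuid4

# Assumed format for illustration; the real `datetime_format` is defined elsewhere.
DATETIME_FORMAT = "%Y%m%d%H%M%S"


def make_output_filename(now: datetime) -> str:
    """Build "<timestamp>-<uuid4>.csv", mirroring the naming in task_predict."""
    timestamp = now.strftime(DATETIME_FORMAT)
    return f"{timestamp}-{uuid4()}.csv"


fixed_time = datetime(2024, 1, 1, 12, 0, 0)
first = make_output_filename(fixed_time)
second = make_output_filename(fixed_time)

# Same timestamp, different random UUIDs, so the two names differ.
print(first.startswith("20240101120000-"))  # True
print(first != second)  # True
```

The timestamp comes first, and the assumed format is zero-padded from the most to the least significant field. Under those conditions, sorting the folder listing by name also orders the files by creation time.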