
[QUESTION] Is it expected behavior that dtype variations between covariate(s) and target for TFMs fail at the PyTorch package level? #2960

@keskinoglu

Description


Describe the issue linked to the documentation
When running a TFM where, for example, target_ts and covariate_ts have different dtypes (e.g. dtype(target_ts) = float32, dtype(covariate_ts) = float64), the mismatch is neither caught nor resolved by Darts. It is passed on to the underlying torch package, where a model-specific check_input function throws ValueError: input must have the type torch.float32, got type torch.float64.
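A minimal reproduction sketch of what I'm seeing, assuming a recent Darts version; the model choice (NBEATSModel) and the series lengths are illustrative only:

```python
import numpy as np
from darts import TimeSeries
from darts.models import NBEATSModel

# Target is float32; the covariate defaults to float64 (np.random.rand).
target_ts = TimeSeries.from_values(np.random.rand(100).astype(np.float32))
covariate_ts = TimeSeries.from_values(np.random.rand(100))

model = NBEATSModel(input_chunk_length=12, output_chunk_length=6, n_epochs=1)

# Fails inside torch with:
#   ValueError: input must have the type torch.float32, got type torch.float64
# model.fit(target_ts, past_covariates=covariate_ts)

# Current workaround: cast the covariates to the target's dtype before fitting.
model.fit(target_ts, past_covariates=covariate_ts.astype(np.float32))
```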

Is this expected behavior or am I doing something wrong?

Additional context
This appears to happen because the dtype of the underlying TFM is inferred only from the dtype of the target: line 446 in torch_forecasting_model.py, inside TorchForecastingModel's _init_model() function, reads dtype = self.train_sample[0].dtype, where train_sample is instantiated during that particular model's _create_model().
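To illustrate (a hypothetical mock of the mechanism, not Darts' actual train_sample construction):

```python
import numpy as np

# The train sample is a tuple of arrays; only its first element
# (the target slice) is consulted when the model dtype is inferred.
train_sample = (
    np.zeros((12, 1), dtype=np.float32),  # target slice -> sets model dtype
    np.zeros((12, 1), dtype=np.float64),  # covariate slice -> never checked
)

dtype = train_sample[0].dtype  # mirrors dtype = self.train_sample[0].dtype
print(dtype)  # float32, so torch later rejects the float64 covariates
```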

Potential improvement
If this is expected behavior, it feels like the mismatch should be caught earlier in the process, since without knowing Darts' train_sample mechanism it's unclear where the dtype error is coming from.

Ideally, the TFM would be set to the highest-precision dtype seen across the target and the covariates, the model would be instantiated with that dtype, and the lower-precision series would be upcast. This may hurt performance somewhat, but it does not alter the data values themselves.
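A sketch of what that could look like; align_dtypes is a hypothetical helper name, and np.result_type is one way to pick the highest-precision common dtype:

```python
import numpy as np
from darts import TimeSeries

def align_dtypes(target: TimeSeries, covariates: TimeSeries):
    # result_type(float32, float64) -> float64, so precision is never lost.
    common = np.result_type(target.dtype, covariates.dtype)
    if target.dtype != common:
        target = target.astype(common)  # lossless upcast
    if covariates.dtype != common:
        covariates = covariates.astype(common)
    return target, covariates
```

Darts could apply something like this (or at least raise a descriptive error) before the train sample ever reaches torch.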

I also don't recall the docs saying this should be handled by the Darts user, but I haven't memorized all of Darts' docs (yet ;) ) and could very well be wrong.
