Description
Describe the issue linked to the documentation
When training a TFM (TorchForecastingModel) whose target and covariate series have different dtypes (e.g. target_ts is float32 and covariate_ts is float64), Darts neither catches nor resolves the mismatch. It is passed on to the underlying torch model, where a model-specific check_input function raises ValueError: input must have the type torch.float32, got type torch.float64.
Is this expected behavior or am I doing something wrong?
Additional context
This appears to happen because the dtype of the underlying TFM is inferred only from the dtype of the target. Line 446 of torch_forecasting_model.py, in TorchForecastingModel's _init_model() function, reads dtype = self.train_sample[0].dtype, where train_sample is instantiated in that particular model's _create_model().
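To make the mechanism concrete, here is a minimal sketch that uses plain NumPy arrays as stand-ins for the tensors Darts builds. The sample layout (past target, past covariates) and the helpers infer_model_dtype and check_input_dtype are hypothetical; only the "dtype comes from train_sample[0]" rule is taken from the Darts source quoted above, and the error text only loosely mimics the one torch raises. Checking each component separately is also a simplification of what the real model does.

```python
import numpy as np


def infer_model_dtype(train_sample):
    # Mirrors `dtype = self.train_sample[0].dtype`: only the first element
    # (the past target) is inspected; covariate dtypes are ignored.
    return train_sample[0].dtype


def check_input_dtype(x, expected):
    # Hypothetical stand-in for a model-specific input check that only runs
    # at forward time, long after the model has been built.
    if x.dtype != expected:
        raise ValueError(f"input must have the type {expected}, got type {x.dtype}")


# Hypothetical training sample: float32 target, float64 past covariates.
train_sample = (
    np.zeros((24, 1), dtype=np.float32),
    np.zeros((24, 2), dtype=np.float64),
)

model_dtype = infer_model_dtype(train_sample)
print(model_dtype)  # float32

try:
    for component in train_sample:
        check_input_dtype(component, model_dtype)
except ValueError as err:
    # The target passes; the covariates fail only at this late stage.
    print(err)
```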
Potential improvement
If this is expected behavior, it feels like the mismatch should be caught earlier in the process, since without knowing about Darts' train_sample mechanism it is unclear why a dtype mismatch occurs at all.
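One way to catch the mismatch earlier would be to compare dtypes before any model is built and fail with a message that names every offending series. The sketch below assumes NumPy arrays as stand-ins for the series values; validate_dtypes is a hypothetical helper, not part of Darts. With real Darts objects, the same comparison could presumably be made on each TimeSeries' dtype attribute.

```python
import numpy as np


def validate_dtypes(target, **covariates):
    """Raise early if any provided covariate's dtype differs from the target's."""
    # Collect every covariate (skipping absent ones) whose dtype differs.
    mismatches = {
        name: cov.dtype
        for name, cov in covariates.items()
        if cov is not None and cov.dtype != target.dtype
    }
    if mismatches:
        details = ", ".join(f"{name}={dtype}" for name, dtype in mismatches.items())
        raise ValueError(
            f"target has dtype {target.dtype}, but covariates differ: {details}; "
            "cast them to a common dtype before fitting"
        )


target = np.zeros((100, 1), dtype=np.float32)
past_covariates = np.zeros((100, 3), dtype=np.float64)

try:
    validate_dtypes(target, past_covariates=past_covariates, future_covariates=None)
except ValueError as err:
    print(err)
```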
Ideally, the TFM would use the highest-precision dtype found in either the target or the covariates, instantiate the model with that dtype, and upcast the lower-precision series. This may hurt computational performance (memory and speed) but does not change the data values themselves.
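The proposed behavior (pick the highest-precision dtype among target and covariates, then upcast the rest) can be sketched with numpy.result_type, which returns the dtype NumPy's promotion rules produce for its inputs. The helper name upcast_to_common_dtype and the use of bare NumPy arrays are assumptions for illustration; inside Darts, the cast would presumably be applied with TimeSeries.astype.

```python
import numpy as np


def upcast_to_common_dtype(*arrays):
    # Skip absent covariates (None), find the common dtype under NumPy's
    # promotion rules (float32 + float64 -> float64), and cast to it.
    present = [a for a in arrays if a is not None]
    common = np.result_type(*present)
    # copy=False avoids copying arrays that already have the common dtype.
    cast = [None if a is None else a.astype(common, copy=False) for a in arrays]
    return common, cast


target = np.linspace(0, 1, 5, dtype=np.float32)
past_covariates = np.linspace(0, 1, 5, dtype=np.float64)

dtype, (target_up, cov_up) = upcast_to_common_dtype(target, past_covariates)
print(dtype)  # float64
```

Every float32 value is exactly representable in float64, so this upcast leaves the data unchanged; the cost is roughly double the memory for the upcast series.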
I also don't recall seeing anything saying this should be handled by the Darts user, but I haven't memorized all of Darts' docs (yet ;) ) and could very well be wrong.