graykode

(fixed) CUDA runtime error (59) : device-side assert triggered error

......@@ -125,6 +125,7 @@ class BaseTransformer(pl.LightningModule):
)
else:
self.model = model
self.model.resize_token_embeddings(len(tokenizer))
def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)
......