(fixed) CUDA runtime error (59) : device-side assert triggered error
Showing
1 changed file
with
1 additions
and
0 deletions
... | @@ -125,6 +125,7 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -125,6 +125,7 @@ class BaseTransformer(pl.LightningModule): |
125 | ) | 125 | ) |
126 | else: | 126 | else: |
127 | self.model = model | 127 | self.model = model |
128 | + self.model.resize_token_embeddings(len(tokenizer)) | ||
128 | 129 | ||
129 | def load_hf_checkpoint(self, *args, **kwargs): | 130 | def load_hf_checkpoint(self, *args, **kwargs): |
130 | self.model = self.model_type.from_pretrained(*args, **kwargs) | 131 | self.model = self.model_type.from_pretrained(*args, **kwargs) | ... | ... |
-
Please register or login to post a comment