graykode

(fixed) CUDA runtime error (59) : device-side assert triggered error

...@@ -125,6 +125,7 @@ class BaseTransformer(pl.LightningModule): ...@@ -125,6 +125,7 @@ class BaseTransformer(pl.LightningModule):
125 ) 125 )
126 else: 126 else:
127 self.model = model 127 self.model = model
128 + self.model.resize_token_embeddings(len(tokenizer))
128 129
129 def load_hf_checkpoint(self, *args, **kwargs): 130 def load_hf_checkpoint(self, *args, **kwargs):
130 self.model = self.model_type.from_pretrained(*args, **kwargs) 131 self.model = self.model_type.from_pretrained(*args, **kwargs)
......