graykode

(refactor) fixed relative import paths

1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +from .modeling_bart import BartForConditionalGeneration
16 +
17 +__all__ = [
18 + 'BartForConditionalGeneration'
19 +]
...\ No newline at end of file ...\ No newline at end of file
...@@ -21,7 +21,7 @@ from transformers import ( ...@@ -21,7 +21,7 @@ from transformers import (
21 PretrainedConfig, 21 PretrainedConfig,
22 PreTrainedTokenizer, 22 PreTrainedTokenizer,
23 ) 23 )
24 -from modeling_bart import BartForConditionalGeneration 24 +from .modeling_bart import BartForConditionalGeneration
25 25
26 from transformers.optimization import ( 26 from transformers.optimization import (
27 Adafactor, 27 Adafactor,
......
...@@ -41,7 +41,7 @@ from transformers.modeling_outputs import ( ...@@ -41,7 +41,7 @@ from transformers.modeling_outputs import (
41 Seq2SeqQuestionAnsweringModelOutput, 41 Seq2SeqQuestionAnsweringModelOutput,
42 Seq2SeqSequenceClassifierOutput, 42 Seq2SeqSequenceClassifierOutput,
43 ) 43 )
44 -from modeling_utils import PreTrainedModel 44 +from .modeling_utils import PreTrainedModel
45 import logging 45 import logging
46 46
47 logger = logging.getLogger(__name__) # pylint: disable=invalid-name 47 logger = logging.getLogger(__name__) # pylint: disable=invalid-name
......
...@@ -39,7 +39,7 @@ from transformers.file_utils import ( ...@@ -39,7 +39,7 @@ from transformers.file_utils import (
39 is_torch_tpu_available, 39 is_torch_tpu_available,
40 replace_return_docstrings, 40 replace_return_docstrings,
41 ) 41 )
42 -from generation_utils import GenerationMixin 42 +from .generation_utils import GenerationMixin
43 import logging 43 import logging
44 44
45 logger = logging.getLogger(__name__) # pylint: disable=invalid-name 45 logger = logging.getLogger(__name__) # pylint: disable=invalid-name
......