Showing
4 changed files
with
22 additions
and
3 deletions
train/__init__.py
0 → 100644
1 | +# Copyright 2020-present Tae Hwan Jung | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +from .modeling_bart import BartForConditionalGeneration | ||
16 | + | ||
17 | +__all__ = [ | ||
18 | + 'BartForConditionalGeneration' | ||
19 | +] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -21,7 +21,7 @@ from transformers import ( | ... | @@ -21,7 +21,7 @@ from transformers import ( |
21 | PretrainedConfig, | 21 | PretrainedConfig, |
22 | PreTrainedTokenizer, | 22 | PreTrainedTokenizer, |
23 | ) | 23 | ) |
24 | -from modeling_bart import BartForConditionalGeneration | 24 | +from .modeling_bart import BartForConditionalGeneration |
25 | 25 | ||
26 | from transformers.optimization import ( | 26 | from transformers.optimization import ( |
27 | Adafactor, | 27 | Adafactor, | ... | ... |
... | @@ -41,7 +41,7 @@ from transformers.modeling_outputs import ( | ... | @@ -41,7 +41,7 @@ from transformers.modeling_outputs import ( |
41 | Seq2SeqQuestionAnsweringModelOutput, | 41 | Seq2SeqQuestionAnsweringModelOutput, |
42 | Seq2SeqSequenceClassifierOutput, | 42 | Seq2SeqSequenceClassifierOutput, |
43 | ) | 43 | ) |
44 | -from modeling_utils import PreTrainedModel | 44 | +from .modeling_utils import PreTrainedModel |
45 | import logging | 45 | import logging |
46 | 46 | ||
47 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | 47 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ... | ... |
... | @@ -39,7 +39,7 @@ from transformers.file_utils import ( | ... | @@ -39,7 +39,7 @@ from transformers.file_utils import ( |
39 | is_torch_tpu_available, | 39 | is_torch_tpu_available, |
40 | replace_return_docstrings, | 40 | replace_return_docstrings, |
41 | ) | 41 | ) |
42 | -from generation_utils import GenerationMixin | 42 | +from .generation_utils import GenerationMixin |
43 | import logging | 43 | import logging |
44 | 44 | ||
45 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | 45 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ... | ... |
-
Please register or login to post a comment