Toggle navigation
Toggle navigation
This project
Loading...
Sign in
Hyunji
/
CapstoneDesign2021-1
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
Hyunji
2021-06-21 19:22:46 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
4062de23a12532ef3088410b0ea4bc278a85b922
4062de23
1 parent
32c28396
base model code
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
0 deletions
lib/base_model.py
lib/base_model.py
0 → 100644
View file @
4062de2
""" base model"""
import
logging
import
numpy
as
np
import
torch.nn
as
nn
logger
=
logging
.
getLogger
()
class
Base
(
nn
.
Module
):
""" Base model with some util functions"""
def
stats
(
self
,
print_model
=
True
):
# print network model and information about parameters
logger
.
info
(
"Model info:::"
)
if
print_model
:
logger
.
info
(
self
)
count
=
0
for
i
in
self
.
parameters
():
count
+=
np
.
prod
(
i
.
shape
)
logger
.
info
(
f
"Total parameters : {count}"
)
def
to
(
self
,
*
args
,
**
kwargs
):
if
kwargs
.
get
(
"device"
):
self
.
device
=
kwargs
.
get
(
"device"
)
if
len
(
args
)
>
0
:
self
.
device
=
args
[
0
]
return
super
()
.
to
(
*
args
,
**
kwargs
)
def
forward
(
self
,
x
):
raise
NotImplementedError
()
Please
register
or
login
to post a comment