PulseFocusPlatform/ppdet/modeling/architectures/meta_arch.py

45 lines
1.1 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from ppdet.core.workspace import register
# Names exported by a star-import of this module.
__all__ = ['BaseArch']
@register
class BaseArch(nn.Layer):
    """Base class that drives the forward pass of an architecture.

    ``forward`` stores the input dict on ``self.inputs``, calls
    ``model_arch()``, and then returns ``get_loss()`` when ``self.training``
    is truthy or ``get_pred()`` otherwise. ``get_loss`` and ``get_pred``
    raise ``NotImplementedError`` here and must be overridden;
    ``model_arch`` is a no-op unless overridden.

    Args:
        data_format (str): Layout tag; ``forward`` transposes the image
            only when this equals ``'NHWC'``. Defaults to ``'NCHW'``.
    """

    def __init__(self, data_format='NCHW'):
        super(BaseArch, self).__init__()
        self.data_format = data_format

    def forward(self, inputs):
        """Run the architecture and return its loss or prediction.

        Args:
            inputs (dict): Input mapping. When ``data_format`` is ``'NHWC'``
                it must contain an ``'image'`` entry.
                NOTE(review): the image is presumably NCHW on arrival --
                confirm against the data pipeline.

        Returns:
            The result of ``get_loss()`` when ``self.training`` is truthy,
            otherwise the result of ``get_pred()``.
        """
        if self.data_format == 'NHWC':
            # Work on a shallow copy so the caller's dict is left untouched.
            # Overwriting the entry in place would make a second call with
            # the same dict transpose an already-transposed image.
            inputs = dict(inputs)
            # Permutation [0, 2, 3, 1] moves axis 1 to the last position.
            inputs['image'] = paddle.transpose(inputs['image'], [0, 2, 3, 1])

        # Expose the inputs as instance state before invoking the hooks.
        self.inputs = inputs
        self.model_arch()

        if self.training:
            return self.get_loss()
        return self.get_pred()

    def build_inputs(self, data, input_def):
        """Pair each name in ``input_def`` with the same-index item of ``data``.

        Args:
            data: Indexable container; ``data[i]`` is read for every
                position ``i`` of ``input_def``.
            input_def: Iterable of keys for the resulting dict.

        Returns:
            dict: ``{input_def[i]: data[i]}`` for every position ``i``.
        """
        return {key: data[idx] for idx, key in enumerate(input_def)}

    def model_arch(self):
        """Hook called by ``forward`` before loss/prediction; no-op here."""
        pass

    def get_loss(self):
        """Return the training output; must be overridden by subclasses."""
        raise NotImplementedError("Should implement get_loss method!")

    def get_pred(self):
        """Return the inference output; must be overridden by subclasses."""
        raise NotImplementedError("Should implement get_pred method!")