// tf_nndistance.cpp: TensorFlow "NnDistance" / "NnDistanceGrad" ops with CPU and GPU kernels.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
// Registers the "NnDistance" op interface: two float32 inputs (xyz1, xyz2) and
// four outputs (float32 dist1, int32 idx1, float32 dist2, int32 idx2).
// The kernels in this file require both inputs to have shape (batch,#points,3)
// and allocate dist1/idx1 with shape (batch,n) and dist2/idx2 with shape
// (batch,m), where n and m are the point counts of xyz1 and xyz2.
// No shape function is attached to this registration.
REGISTER_OP("NnDistance")
	.Input("xyz1: float32")
	.Input("xyz2: float32")
	.Output("dist1: float32")
	.Output("idx1: int32")
	.Output("dist2: float32")
	.Output("idx2: int32");
// Registers the "NnDistanceGrad" op interface. Inputs are the two point sets
// (xyz1, xyz2), the incoming gradients for the two distance outputs
// (grad_dist1, grad_dist2) and the two int32 index tensors (idx1, idx2);
// outputs are the float32 gradients grad_xyz1 and grad_xyz2.
// The kernels in this file allocate grad_xyz1 as (batch,n,3) and grad_xyz2 as
// (batch,m,3). No shape function is attached to this registration.
REGISTER_OP("NnDistanceGrad")
	.Input("xyz1: float32")
	.Input("xyz2: float32")
	.Input("grad_dist1: float32")
	.Input("idx1: int32")
	.Input("grad_dist2: float32")
	.Input("idx2: int32")
	.Output("grad_xyz1: float32")
	.Output("grad_xyz2: float32");
using namespace tensorflow;

// Brute-force nearest-neighbour search from xyz1 into xyz2, batch by batch.
//
//   b    : number of batches
//   n    : points per batch in xyz1 (read as b*n*3 floats, [batch][point][xyz])
//   m    : points per batch in xyz2 (read as b*m*3 floats, same layout)
//   dist : output, b*n floats; smallest squared Euclidean distance per xyz1 point
//   idx  : output, b*n ints; index (within the batch) of the xyz2 point reaching it
//
// Ties keep the lowest index. When m==0 every entry is written as distance 0
// and index 0.
static void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){
	for (int batch=0;batch<b;++batch){
		// Base pointers of the current batch in every buffer.
		const float * queries=xyz1+batch*n*3;
		const float * candidates=xyz2+batch*m*3;
		float * out_dist=dist+batch*n;
		int * out_idx=idx+batch*n;
		for (int q=0;q<n;++q){
			const float qx=queries[q*3+0];
			const float qy=queries[q*3+1];
			const float qz=queries[q*3+2];
			double nearest=0;
			int nearest_k=0;
			for (int k=0;k<m;++k){
				const float * c=candidates+k*3;
				const float dx=c[0]-qx;
				const float dy=c[1]-qy;
				const float dz=c[2]-qz;
				// Operands are float; the sum is held in a double and narrowed on store.
				const double sq=dx*dx+dy*dy+dz*dz;
				// The first candidate always seeds the running minimum; later ones
				// must be strictly closer to replace it.
				if (k!=0 && !(sq<nearest))
					continue;
				nearest=sq;
				nearest_k=k;
			}
			out_dist[q]=nearest;
			out_idx[q]=nearest_k;
		}
	}
}

class NnDistanceOp : public OpKernel{
	public:
		explicit NnDistanceOp(OpKernelConstruction* context):OpKernel(context){}
		void Compute(OpKernelContext * context)override{
			const Tensor& xyz1_tensor=context->input(0);
			const Tensor& xyz2_tensor=context->input(1);
			OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1"));
			int b=xyz1_tensor.shape().dim_size(0);
			int n=xyz1_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2"));
			int m=xyz2_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size"));
			auto xyz1_flat=xyz1_tensor.flat<float>();
			const float * xyz1=&xyz1_flat(0);
			auto xyz2_flat=xyz2_tensor.flat<float>();
			const float * xyz2=&xyz2_flat(0);
			Tensor * dist1_tensor=NULL;
			Tensor * idx1_tensor=NULL;
			Tensor * dist2_tensor=NULL;
			Tensor * idx2_tensor=NULL;
			OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor));
			OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor));
			auto dist1_flat=dist1_tensor->flat<float>();
			auto idx1_flat=idx1_tensor->flat<int>();
			OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor));
			OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor));
			auto dist2_flat=dist2_tensor->flat<float>();
			auto idx2_flat=idx2_tensor->flat<int>();
			float * dist1=&(dist1_flat(0));
			int * idx1=&(idx1_flat(0));
			float * dist2=&(dist2_flat(0));
			int * idx2=&(idx2_flat(0));
			nnsearch(b,n,m,xyz1,xyz2,dist1,idx1);
			nnsearch(b,m,n,xyz2,xyz1,dist2,idx2);
		}
};
REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_CPU), NnDistanceOp);
// CPU kernel for the "NnDistanceGrad" op.
//
// Inputs : xyz1 (b,n,3), xyz2 (b,m,3), grad_dist1 (b,n), idx1 (b,n),
//          grad_dist2 (b,m), idx2 (b,m).
// Outputs: grad_xyz1 (b,n,3), grad_xyz2 (b,m,3).
//
// For a point p and the point q selected for it by the index tensor, the
// incoming gradient g of their squared distance adds 2*g*(p-q) to p's gradient
// and subtracts the same amount from q's gradient. Both directions
// (xyz1 -> xyz2 via idx1, xyz2 -> xyz1 via idx2) accumulate into the outputs.
class NnDistanceGradOp : public OpKernel{
	public:
		explicit NnDistanceGradOp(OpKernelConstruction* context):OpKernel(context){}
		// Validates shapes and index ranges, then accumulates the gradients.
		// Any failed check reports InvalidArgument on the context and returns early.
		void Compute(OpKernelContext * context)override{
			const Tensor& xyz1_tensor=context->input(0);
			const Tensor& xyz2_tensor=context->input(1);
			const Tensor& grad_dist1_tensor=context->input(2);
			const Tensor& idx1_tensor=context->input(3);
			const Tensor& grad_dist2_tensor=context->input(4);
			const Tensor& idx2_tensor=context->input(5);
			OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1"));
			int b=xyz1_tensor.shape().dim_size(0);
			int n=xyz1_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2"));
			int m=xyz2_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size"));
			OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)"));
			OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)"));
			OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)"));
			OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)"));
			auto xyz1_flat=xyz1_tensor.flat<float>();
			const float * xyz1=&xyz1_flat(0);
			auto xyz2_flat=xyz2_tensor.flat<float>();
			const float * xyz2=&xyz2_flat(0);
			auto idx1_flat=idx1_tensor.flat<int>();
			const int * idx1=&idx1_flat(0);
			auto idx2_flat=idx2_tensor.flat<int>();
			const int * idx2=&idx2_flat(0);
			auto grad_dist1_flat=grad_dist1_tensor.flat<float>();
			const float * grad_dist1=&grad_dist1_flat(0);
			auto grad_dist2_flat=grad_dist2_tensor.flat<float>();
			const float * grad_dist2=&grad_dist2_flat(0);
			// idx1/idx2 are plain op inputs, so nothing guarantees they came from
			// NnDistance. They are used below as offsets into xyz2/xyz1 and into
			// the gradient buffers, so reject any value outside the valid range
			// before touching memory.
			for (int i=0;i<b*n;i++){
				const int target=idx1[i];
				OP_REQUIRES(context,target>=0 && target<m,errors::InvalidArgument("NnDistanceGrad requires every idx1 entry to index a point of xyz2, got ",target," with #points ",m));
			}
			for (int i=0;i<b*m;i++){
				const int target=idx2[i];
				OP_REQUIRES(context,target>=0 && target<n,errors::InvalidArgument("NnDistanceGrad requires every idx2 entry to index a point of xyz1, got ",target," with #points ",n));
			}
			Tensor * grad_xyz1_tensor=nullptr;
			OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor));
			Tensor * grad_xyz2_tensor=nullptr;
			OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor));
			auto grad_xyz1_flat=grad_xyz1_tensor->flat<float>();
			float * grad_xyz1=&grad_xyz1_flat(0);
			auto grad_xyz2_flat=grad_xyz2_tensor->flat<float>();
			float * grad_xyz2=&grad_xyz2_flat(0);
			// Both passes below accumulate, so start from zero.
			for (int i=0;i<b*n*3;i++)
				grad_xyz1[i]=0;
			for (int i=0;i<b*m*3;i++)
				grad_xyz2[i]=0;
			for (int i=0;i<b;i++){
				// Direction xyz1 -> xyz2: point j of xyz1 paired with point idx1 of xyz2.
				for (int j=0;j<n;j++){
					float x1=xyz1[(i*n+j)*3+0];
					float y1=xyz1[(i*n+j)*3+1];
					float z1=xyz1[(i*n+j)*3+2];
					int j2=idx1[i*n+j];
					float x2=xyz2[(i*m+j2)*3+0];
					float y2=xyz2[(i*m+j2)*3+1];
					float z2=xyz2[(i*m+j2)*3+2];
					// Factor 2 comes from differentiating the squared distance.
					float g=grad_dist1[i*n+j]*2;
					grad_xyz1[(i*n+j)*3+0]+=g*(x1-x2);
					grad_xyz1[(i*n+j)*3+1]+=g*(y1-y2);
					grad_xyz1[(i*n+j)*3+2]+=g*(z1-z2);
					grad_xyz2[(i*m+j2)*3+0]-=(g*(x1-x2));
					grad_xyz2[(i*m+j2)*3+1]-=(g*(y1-y2));
					grad_xyz2[(i*m+j2)*3+2]-=(g*(z1-z2));
				}
				// Direction xyz2 -> xyz1: point j of xyz2 paired with point idx2 of xyz1.
				for (int j=0;j<m;j++){
					float x1=xyz2[(i*m+j)*3+0];
					float y1=xyz2[(i*m+j)*3+1];
					float z1=xyz2[(i*m+j)*3+2];
					int j2=idx2[i*m+j];
					float x2=xyz1[(i*n+j2)*3+0];
					float y2=xyz1[(i*n+j2)*3+1];
					float z2=xyz1[(i*n+j2)*3+2];
					float g=grad_dist2[i*m+j]*2;
					grad_xyz2[(i*m+j)*3+0]+=g*(x1-x2);
					grad_xyz2[(i*m+j)*3+1]+=g*(y1-y2);
					grad_xyz2[(i*m+j)*3+2]+=g*(z1-z2);
					grad_xyz1[(i*n+j2)*3+0]-=(g*(x1-x2));
					grad_xyz1[(i*n+j2)*3+1]-=(g*(y1-y2));
					grad_xyz1[(i*n+j2)*3+2]-=(g*(z1-z2));
				}
			}
		}
};
// Bind the CPU implementation to the "NnDistanceGrad" op.
REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_CPU), NnDistanceGradOp);

void NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i);
class NnDistanceGpuOp : public OpKernel{
	public:
		explicit NnDistanceGpuOp(OpKernelConstruction* context):OpKernel(context){}
		void Compute(OpKernelContext * context)override{
			const Tensor& xyz1_tensor=context->input(0);
			const Tensor& xyz2_tensor=context->input(1);
			OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1"));
			int b=xyz1_tensor.shape().dim_size(0);
			int n=xyz1_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2"));
			int m=xyz2_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size"));
			auto xyz1_flat=xyz1_tensor.flat<float>();
			const float * xyz1=&xyz1_flat(0);
			auto xyz2_flat=xyz2_tensor.flat<float>();
			const float * xyz2=&xyz2_flat(0);
			Tensor * dist1_tensor=NULL;
			Tensor * idx1_tensor=NULL;
			Tensor * dist2_tensor=NULL;
			Tensor * idx2_tensor=NULL;
			OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor));
			OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor));
			auto dist1_flat=dist1_tensor->flat<float>();
			auto idx1_flat=idx1_tensor->flat<int>();
			OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor));
			OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor));
			auto dist2_flat=dist2_tensor->flat<float>();
			auto idx2_flat=idx2_tensor->flat<int>();
			float * dist1=&(dist1_flat(0));
			int * idx1=&(idx1_flat(0));
			float * dist2=&(dist2_flat(0));
			int * idx2=&(idx2_flat(0));
			NmDistanceKernelLauncher(b,n,xyz1,m,xyz2,dist1,idx1,dist2,idx2);
		}
};
REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_GPU), NnDistanceGpuOp);

void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2);
class NnDistanceGradGpuOp : public OpKernel{
	public:
		explicit NnDistanceGradGpuOp(OpKernelConstruction* context):OpKernel(context){}
		void Compute(OpKernelContext * context)override{
			const Tensor& xyz1_tensor=context->input(0);
			const Tensor& xyz2_tensor=context->input(1);
			const Tensor& grad_dist1_tensor=context->input(2);
			const Tensor& idx1_tensor=context->input(3);
			const Tensor& grad_dist2_tensor=context->input(4);
			const Tensor& idx2_tensor=context->input(5);
			OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1"));
			int b=xyz1_tensor.shape().dim_size(0);
			int n=xyz1_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)"));
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2"));
			int m=xyz2_tensor.shape().dim_size(1);
			OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size"));
			OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)"));
			OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)"));
			OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)"));
			OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)"));
			auto xyz1_flat=xyz1_tensor.flat<float>();
			const float * xyz1=&xyz1_flat(0);
			auto xyz2_flat=xyz2_tensor.flat<float>();
			const float * xyz2=&xyz2_flat(0);
			auto idx1_flat=idx1_tensor.flat<int>();
			const int * idx1=&idx1_flat(0);
			auto idx2_flat=idx2_tensor.flat<int>();
			const int * idx2=&idx2_flat(0);
			auto grad_dist1_flat=grad_dist1_tensor.flat<float>();
			const float * grad_dist1=&grad_dist1_flat(0);
			auto grad_dist2_flat=grad_dist2_tensor.flat<float>();
			const float * grad_dist2=&grad_dist2_flat(0);
			Tensor * grad_xyz1_tensor=NULL;
			OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor));
			Tensor * grad_xyz2_tensor=NULL;
			OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor));
			auto grad_xyz1_flat=grad_xyz1_tensor->flat<float>();
			float * grad_xyz1=&grad_xyz1_flat(0);
			auto grad_xyz2_flat=grad_xyz2_tensor->flat<float>();
			float * grad_xyz2=&grad_xyz2_flat(0);
			NmDistanceGradKernelLauncher(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_dist2,idx2,grad_xyz1,grad_xyz2);
		}
};
REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_GPU), NnDistanceGradGpuOp);