Source code for drtk.interpolate

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
The ``drtk.interpolate`` module provides functions for differentiable interpolation of vertex
attributes across the fragments, i.e. the pixels covered by the primitive.
"""

import torch as th
from drtk.utils import load_torch_ops

load_torch_ops("drtk.interpolate_ext")


@th.compiler.disable
def interpolate(
    vert_attributes: th.Tensor,
    vi: th.Tensor,
    index_img: th.Tensor,
    bary_img: th.Tensor,
) -> th.Tensor:
    """
    Linearly interpolates per-vertex attributes over the image, weighting the three
    vertices of the face seen at each pixel by that pixel's barycentric coordinates.

    The work is delegated to the ``interpolate_ext.interpolate`` custom operator.

    Args:
        vert_attributes (th.Tensor): vertex attribute tensor
            N x V x C

        vi (th.Tensor): face vertex index list tensor, one row of three vertex
            indices per face
            F x 3

        index_img (th.Tensor): index image tensor
            N x H x W

        bary_img (th.Tensor): 3D barycentric coordinate image tensor
            N x 3 x H x W

    Returns:
        A tensor with interpolated vertex attributes with a shape [N, C, H, W]

    .. warning::
        Only pixels with a valid index in ``index_img`` carry meaningful values.
        Pixels whose index is ``-1`` are still filled with non-zero values in the
        returned tensor; those values should be ignored.
    """
    interpolate_op = th.ops.interpolate_ext.interpolate
    return interpolate_op(vert_attributes, vi, index_img, bary_img)
def interpolate_ref(
    vert_attributes: th.Tensor,
    vi: th.Tensor,
    index_img: th.Tensor,
    bary_img: th.Tensor,
) -> th.Tensor:
    """
    A reference implementation of :func:`drtk.interpolate` in pure PyTorch.
    This function is used for tests only, please see :func:`drtk.interpolate` for documentation.

    Args:
        vert_attributes (th.Tensor): vertex attribute tensor
            N x V x C

        vi (th.Tensor): face vertex index list tensor; rows are selected by the
            values of ``index_img``
            F x 3

        index_img (th.Tensor): index image tensor, ``-1`` marks pixels without a face
            N x H x W

        bary_img (th.Tensor): 3D barycentric coordinate image tensor
            N x 3 x H x W

    Returns:
        A tensor with interpolated vertex attributes with a shape [N, C, H, W] and the
        dtype of ``vert_attributes``. Pixels where ``index_img == -1`` are filled with
        a coordinate sweep in the range -1..1 instead of interpolated values.
    """
    # Run reference implementation in double precision to get as good reference as possible
    orig_dtype = vert_attributes.dtype
    vert_attributes = vert_attributes.double()
    bary_img = bary_img.double()

    batch = vert_attributes.shape[0]
    channels = vert_attributes.shape[-1]
    height, width = index_img.shape[-2], index_img.shape[-1]
    device = vert_attributes.device

    # Computed once; used both to select the fill values and to write them back.
    invalid = index_img == -1

    # Clamp the -1 entries to face 0 so the lookups below stay in range; the values
    # produced for those pixels are overwritten at the end.
    iimg_clamped = index_img.clamp(min=0).long()
    vi_img = vi[iimg_clamped].long()  # N x H x W x 3

    # Fetch the attributes of the three vertices of the face seen at every pixel.
    gathered = th.gather(
        vert_attributes,
        1,
        vi_img.view(batch, -1, 1).expand(-1, -1, channels),
    )
    corner_attrs = (
        gathered.view(*vi_img.shape[:3], 3, channels)
        .permute(0, 3, 1, 2, 4)
        .contiguous()
    )  # N x 3 x H x W x C

    # Barycentric blend of the three corner attributes.
    v_img = (corner_attrs * bary_img[..., None]).sum(dim=1)  # N x H x W x C

    # Do the sweep of values in the range -1..1 for the `index_img == -1` region, like
    # it is done in the CUDA kernel: pixel-centre x and y coordinates, alternating
    # across the channels.
    xs = (th.arange(0, width, device=device).double() * 2.0 + 1.0) / width - 1.0
    ys = (th.arange(0, height, device=device).double() * 2.0 + 1.0) / height - 1.0
    sweep = th.stack(
        [
            xs[None, :].expand(height, width),
            ys[:, None].expand(height, width),
        ],
        dim=2,
    )  # H x W x 2

    # Repeat the (x, y) pair often enough to cover all channels, then trim to C.
    undefined_region = th.tile(sweep[None], dims=[1, 1, 1, (channels + 1) // 2])[
        :, :, :, :channels
    ]  # 1 x H x W x C

    v_img[invalid] = undefined_region.expand(index_img.shape[0], -1, -1, -1)[
        invalid, :
    ]
    return v_img.permute(0, 3, 1, 2).to(orig_dtype)