Description
When I pass tensors with gradients into the forward function of `sys` (an object of class `BuildingEnvelope`), the gradients are stripped at the point where `get_q` is called. This happens even if `sys` is instantiated with the kwargs `backend='torch', requires_grad=True`. This affects any system that uses `BuildingEnvelope`: in the following setup, the gradients from a loss computed on `'yn'` cannot propagate back to `policy` (a `blocks.MLP_bounds`).
```python
policy_node = Node(policy, ['yn', 'D', 'UB', 'LB'], ['U'], name='policy')
system_node = Node(sys, ['xn', 'U', 'Dhidden'], ['xn', 'yn'], name='system')
cl_system = System([policy_node, system_node], nsteps=args.nsteps, name='cl_system')
```
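For concreteness, here is a hypothetical way the symptom shows up (the keys follow the nodes above, but the initial tensors and the squared loss are placeholders, not my actual training script):

```python
# Hypothetical reproduction (x0, y0, d, dh, ub, lb are placeholder tensors
# with whatever shapes the dataset uses, typically (batch, nsteps, dim)):
data = {'xn': x0, 'yn': y0, 'D': d, 'Dhidden': dh, 'UB': ub, 'LB': lb}
out = cl_system(data)
loss = out['yn'].pow(2).mean()
loss.backward()
# With the bug, no gradient ever reaches the policy:
print(any(p.grad is not None for p in policy.parameters()))  # False
```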
The cause appears to be that `BuildingEnvelope.get_q` is wrapped by `@cast_backend`, which calls `torch.tensor(return_tensor, dtype=torch.float32)` on the tensor returned by `get_q`; re-wrapping a tensor this way detaches it from the autograd graph. If this line is removed, the gradients propagate and the policy can be trained normally.
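A minimal sketch of the detachment, plus one possible fix (the `cast_to_torch` helper below is an assumption for illustration, not the library's current code):

```python
import torch

# Rebuilding a tensor with torch.tensor() copies its values but drops its
# autograd history (PyTorch also emits a UserWarning here recommending
# sourceTensor.clone().detach()):
x = torch.ones(3, requires_grad=True)
y = 2 * x                                 # y.grad_fn is set
z = torch.tensor(y, dtype=torch.float32)  # z is detached from the graph
print(y.requires_grad, z.requires_grad)   # True False

# One possible fix: only rebuild non-tensor inputs, and cast tensors
# in place, which preserves grad_fn.
def cast_to_torch(value):
    if isinstance(value, torch.Tensor):
        return value.to(torch.float32)    # dtype cast keeps the graph
    return torch.tensor(value, dtype=torch.float32)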