
get_q in BuildingEnvelope Disconnects Gradients #118

Open
@HarryLTS

Description

When I pass tensors with gradients into the forward function of sys (an instance of BuildingEnvelope), the gradients are severed at the point where get_q is called. This happens even when sys is instantiated with the kwargs backend='torch', requires_grad=True. The problem affects any system that uses BuildingEnvelope: in the following setup, the gradients from a loss computed on 'yn' cannot propagate back to policy (a blocks.MLP_bounds).

    policy_node = Node(policy, ['yn', 'D', 'UB', 'LB'], ['U'], name='policy')
    system_node = Node(sys, ['xn', 'U', 'Dhidden'], ['xn', 'yn'], name='system')

    cl_system = System([policy_node, system_node], nsteps=args.nsteps, name='cl_system')

The cause of this issue seems to be that BuildingEnvelope.get_q is wrapped by @cast_backend, which calls torch.tensor(return_tensor, dtype=torch.float32) on the tensor returned by get_q. torch.tensor always copies its argument into a new leaf tensor, so the result is detached from the autograd graph. If this line is removed, the gradients are able to propagate and the policy can be trained normally.
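The detach behavior described above can be reproduced outside neuromancer. The sketch below (names like q are stand-ins, not the library's actual variables) shows that torch.tensor() on an existing tensor drops the autograd graph, and shows one grad-preserving alternative: only convert inputs that are not already tensors.

```python
# Minimal repro of the gradient break, independent of neuromancer.
import torch

x = torch.ones(3, requires_grad=True)
q = x * 2.0  # stands in for the tensor returned by get_q; part of x's graph

# What the @cast_backend wrapper effectively does:
detached = torch.tensor(q, dtype=torch.float32)  # copies data, drops the graph
print(detached.requires_grad)  # False -- backward() can no longer reach x

# One possible fix: pass tensors through untouched, cast only non-tensors.
preserved = q if isinstance(q, torch.Tensor) else torch.as_tensor(q, dtype=torch.float32)
preserved.sum().backward()
print(x.grad)  # gradient flows back to x
```

This mirrors PyTorch's own guidance: torch.tensor is a copy constructor, while torch.as_tensor avoids copies where possible (though as_tensor on a NumPy array still produces a leaf with no graph, so the isinstance check is what actually preserves gradients here).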
