-
Notifications
You must be signed in to change notification settings - Fork 65
Issues with pytorch bindings in the case where structural_feasibility=False. #414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
oomcth
wants to merge
24
commits into
Simple-Robotics:devel
Choose a base branch
from
oomcth:devel
base: devel
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
5b3671e
Add a test for the pytorch interface
oomcth 10de54e
Fix batched unfeasible QP
oomcth ccfc91a
Rename cvxpy.py to cvxpy_test.py
oomcth e736b9c
Add some checks for the gradients
oomcth 0166ef0
Fix some bugs for the backward infeasible case
oomcth 1b75a6e
Merge branch 'Simple-Robotics:devel' into devel
oomcth be7f5af
Delete test/src/cvxpy_test.py
oomcth 154e858
Add files via upload
oomcth fb9c80a
Add files via upload
oomcth acb0276
now test handle with and without eq/neq
oomcth 538459e
Add back cvxpy.py
oomcth 2630af4
fixed a ruff error
oomcth af8c876
Fixed a ruff error
oomcth 03b3ee1
Added checks for gradients consistency
oomcth b52112a
Added fix for gradients
oomcth ea6f117
Gradients checks are not enough.
oomcth a25953b
Fixed some minor gradient bugs in the case where the inputs were batc…
oomcth a19542d
Fix minor error in docstrings
oomcth e507b34
changed max_iter_in value
oomcth c1be573
Delete test/src/QPLayer_test.py
oomcth 4bde1fb
Add files via upload
oomcth f505e89
Delete bindings/python/proxsuite/torch/torch_bindings.py
oomcth 1e4f6b6
Changed back max_iter_in to 100
oomcth 81ff8fb
Merge branch 'Simple-Robotics:devel' into devel
oomcth File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,7 +91,10 @@ def QPFunction( | |
class QPFunctionFn(Function): | ||
@staticmethod | ||
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | ||
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_) | ||
if len(Q_.size()) == 3: | ||
nBatch = Q_.size(0) | ||
else: | ||
nBatch = 1 | ||
Q, _ = expandParam(Q_, nBatch, 3) | ||
p, _ = expandParam(p_, nBatch, 2) | ||
G, _ = expandParam(G_, nBatch, 3) | ||
|
@@ -103,8 +106,13 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | |
ctx.vector_of_qps = proxsuite.proxqp.dense.BatchQP() | ||
|
||
ctx.nBatch = nBatch | ||
|
||
_, nineq, nz = G.size() | ||
do_neq = True | ||
if len(G.size()) == 3 or len(G.size()) == 2: | ||
nineq, nz = G.size()[1:] | ||
else: | ||
nineq = 0 | ||
nz = Q.size()[-1] | ||
do_neq = False | ||
neq = A.size(1) if A.nelement() > 0 else 0 | ||
assert neq > 0 or nineq > 0 | ||
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz | ||
|
@@ -134,21 +142,20 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | |
if p[i] is not None: | ||
p__ = p[i].cpu().numpy() | ||
G__ = None | ||
if G[i] is not None: | ||
if do_neq and G[i] is not None: | ||
G__ = G[i].cpu().numpy() | ||
u__ = None | ||
if u[i] is not None: | ||
if do_neq and u[i] is not None: | ||
u__ = u[i].cpu().numpy() | ||
l__ = None | ||
if l[i] is not None: | ||
if do_neq and l[i] is not None: | ||
l__ = l[i].cpu().numpy() | ||
A__ = None | ||
if Ai is not None: | ||
A__ = Ai.cpu().numpy() | ||
b__ = None | ||
if bi is not None: | ||
b__ = bi.cpu().numpy() | ||
|
||
qp.init( | ||
H=H__, g=p__, A=A__, b=b__, C=G__, l=l__, u=u__, rho=default_rho | ||
) | ||
|
@@ -255,8 +262,22 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus): | |
class QPFunctionFn_infeas(Function): | ||
@staticmethod | ||
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | ||
n_in, nz = G_.size() # true double-sided inequality size | ||
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_) | ||
|
||
do_neq = True | ||
if len(G_.size()) == 3: | ||
_, n_in, nz = G_.size() | ||
elif len(G_.size()) == 2: | ||
n_in = G_.size()[-2] | ||
nz = G_.size()[-1] | ||
else: | ||
n_in = Q_.size()[-1] | ||
nz = Q_.size()[-1] | ||
do_neq = False | ||
ctx.G_size = G_.size() | ||
if len(Q_.size()) == 3: | ||
nBatch = Q_.size(0) | ||
else: | ||
nBatch = 1 | ||
|
||
Q, _ = expandParam(Q_, nBatch, 3) | ||
p, _ = expandParam(p_, nBatch, 2) | ||
|
@@ -268,32 +289,43 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | |
|
||
h = torch.cat((-l, u), axis=1) # single-sided inequality | ||
G = torch.cat((-G, G), axis=1) # single-sided inequality | ||
|
||
_, nineq, nz = G.size() | ||
neq = A.size(1) if A.nelement() > 0 else 0 | ||
if len(G.size()) == 3: | ||
_, nineq, nz = G.size() | ||
else: | ||
nineq = 0 | ||
nz = Q.size()[-1] | ||
if len(A.size()) == 3 or len(A.size()) == 2: | ||
neq = A.size(-2) if A.nelement() > 0 else 0 | ||
else: | ||
neq = 0 | ||
assert neq > 0 or nineq > 0 | ||
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz | ||
|
||
zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype) | ||
nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype) | ||
nus_sol = torch.empty( | ||
(nBatch, n_in), dtype=Q.dtype | ||
) # double-sided inequality multiplier | ||
if do_neq: | ||
nus_sol = torch.empty( | ||
(nBatch, n_in), dtype=Q.dtype | ||
) # double-sided inequality multiplier | ||
else: | ||
nus_sol = None | ||
lams = ( | ||
torch.empty(nBatch, ctx.neq, dtype=Q.dtype) | ||
if ctx.neq > 0 | ||
else torch.empty() | ||
else torch.tensor([]) | ||
) | ||
s_e = ( | ||
torch.empty(nBatch, ctx.neq, dtype=Q.dtype) | ||
if ctx.neq > 0 | ||
else torch.empty() | ||
else torch.tensor([]) | ||
) | ||
slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype) | ||
s_i = torch.empty( | ||
(nBatch, n_in), dtype=Q.dtype | ||
) # this one is of size the one of the original n_in | ||
|
||
if do_neq: | ||
s_i = torch.empty( | ||
(nBatch, n_in), dtype=Q.dtype | ||
) # this one is of size the one of the original n_in | ||
else: | ||
s_i = None | ||
vector_of_qps = proxsuite.proxqp.dense.BatchQP() | ||
|
||
ctx.cpu = os.cpu_count() | ||
|
@@ -311,18 +343,21 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | |
qp.settings.refactor_rho_threshold = default_rho # no refactorization | ||
qp.settings.eps_abs = eps | ||
Ai, bi = (A[i], b[i]) if neq > 0 else (None, None) | ||
|
||
H__ = None | ||
if Q[i] is not None: | ||
H__ = Q[i].cpu().numpy() | ||
p__ = None | ||
if p[i] is not None: | ||
p__ = p[i].cpu().numpy() | ||
G__ = None | ||
if G[i] is not None: | ||
if do_neq and G[i] is not None: | ||
G__ = G[i].cpu().numpy() | ||
u__ = None | ||
if h[i] is not None: | ||
if do_neq and h[i] is not None: | ||
u__ = h[i].cpu().numpy() | ||
if not do_neq: | ||
l = None | ||
# l__ = None | ||
# if (l[i] is not None): | ||
# l__ = l[i].cpu().numpy() | ||
|
@@ -332,7 +367,6 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | |
b__ = None | ||
if bi is not None: | ||
b__ = bi.cpu().numpy() | ||
|
||
qp.init(H=H__, g=p__, A=A__, b=b__, C=G__, l=l, u=u__, rho=default_rho) | ||
|
||
if proxqp_parallel: | ||
|
@@ -348,16 +382,18 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | |
if nineq > 0: | ||
# we re-convert the solution to a double sided inequality QP | ||
slack = -h[i] + G[i] @ vector_of_qps.get(i).results.x | ||
nus_sol[i] = torch.Tensor( | ||
-vector_of_qps.get(i).results.z[:n_in] | ||
+ vector_of_qps.get(i).results.z[n_in:] | ||
) # de-projecting this one may provoke loss of information when using inexact solution | ||
if do_neq: | ||
nus_sol[i] = torch.Tensor( | ||
-vector_of_qps.get(i).results.z[:n_in] | ||
+ vector_of_qps.get(i).results.z[n_in:] | ||
) # de-projecting this one may provoke loss of information when using inexact solution | ||
nus[i] = torch.tensor(vector_of_qps.get(i).results.z) | ||
slacks[i] = slack.clone().detach() | ||
s_i[i] = torch.tensor( | ||
-vector_of_qps.get(i).results.si[:n_in] | ||
+ vector_of_qps.get(i).results.si[n_in:] | ||
) | ||
if do_neq: | ||
s_i[i] = torch.tensor( | ||
-vector_of_qps.get(i).results.si[:n_in] | ||
+ vector_of_qps.get(i).results.si[n_in:] | ||
) | ||
if neq > 0: | ||
lams[i] = torch.tensor(vector_of_qps.get(i).results.y) | ||
s_e[i] = torch.tensor(vector_of_qps.get(i).results.se) | ||
|
@@ -371,7 +407,10 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): | |
@staticmethod | ||
def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | ||
zhats, s_e, Q, p, G, l, u, A, b = ctx.saved_tensors | ||
nBatch = extract_nBatch(Q, p, A, b, G, l, u) | ||
if len(Q.size()) == 3: | ||
nBatch = Q.size(0) | ||
else: | ||
nBatch = 1 | ||
|
||
Q, Q_e = expandParam(Q, nBatch, 3) | ||
p, p_e = expandParam(p, nBatch, 2) | ||
|
@@ -414,7 +453,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | |
|
||
for i in range(nBatch): | ||
Q_i = Q[i].numpy() | ||
C_i = G[i].numpy() | ||
C_i = None | ||
if G is not None and G.numel() != 0: | ||
C_i = G[i].numpy() | ||
A_i = None | ||
if A is not None: | ||
if A.shape[0] != 0: | ||
|
@@ -436,16 +477,17 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | |
if neq > 0: | ||
kkt[:dim, dim : dim + n_eq] = A_i.transpose() | ||
kkt[dim : dim + n_eq, :dim] = A_i | ||
kkt[ | ||
dim + n_eq + n_in : dim + 2 * n_eq + n_in, dim : dim + n_eq | ||
] = -np.eye(n_eq) | ||
kkt[dim + n_eq + n_in : dim + 2 * n_eq + n_in, dim : dim + n_eq] = ( | ||
-np.eye(n_eq) | ||
) | ||
kkt[ | ||
dim + n_eq + n_in : dim + 2 * n_eq + n_in, | ||
dim + n_eq + 2 * n_in : 2 * dim + n_eq + 2 * n_in, | ||
] = A_i | ||
|
||
kkt[:dim, dim + n_eq : dim + n_eq + n_in] = C_i.transpose() | ||
kkt[dim + n_eq : dim + n_eq + n_in, :dim] = C_i | ||
if n_in > 0: | ||
kkt[:dim, dim + n_eq : dim + n_eq + n_in] = C_i.transpose() | ||
kkt[dim + n_eq : dim + n_eq + n_in, :dim] = C_i | ||
|
||
D_1_c = np.eye(n_in) # represents [s_i]_- + z_i < 0 | ||
D_1_c[P_1, P_1] = 0.0 | ||
|
@@ -485,9 +527,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | |
rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[ | ||
i | ||
][~active_set] | ||
rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][ | ||
active_set | ||
] = -dl_dnus[i][active_set] | ||
rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][active_set] = ( | ||
-dl_dnus[i][active_set] | ||
) | ||
Comment on lines
+530
to
+532
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a file formatting change? Did it come from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its a formatting change. It comes from ruff on my computer. |
||
if dl_ds_e is not None: | ||
if dl_ds_e.shape[0] != 0: | ||
rhs[dim + n_eq + n_in : dim + 2 * n_eq + n_in] = -dl_ds_e[i] | ||
|
@@ -515,9 +557,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | |
|
||
qp.settings.primal_infeasibility_solving = True | ||
qp.settings.eps_abs = eps_backward | ||
qp.settings.max_iter = 10 | ||
qp.settings.default_rho = 1.0e-3 | ||
qp.settings.refactor_rho_threshold = 1.0e-3 | ||
qp.settings.max_iter = 1000 | ||
qp.settings.default_rho = 5.0e-5 | ||
qp.settings.refactor_rho_threshold = 5.0e-5 | ||
qp.init( | ||
H, | ||
g, | ||
|
@@ -542,13 +584,19 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | |
) | ||
if n_eq > 0: | ||
dlam[i] = torch.from_numpy( | ||
np.float64(vector_of_qps.get(i).results.x[dim : dim + n_eq]) | ||
vector_of_qps.get(i) | ||
.results.x[dim : dim + n_eq] | ||
.astype(np.float64) | ||
) | ||
dnu[i] = torch.from_numpy( | ||
np.float64( | ||
vector_of_qps.get(i).results.x[dim + n_eq : dim + n_eq + n_in] | ||
|
||
if dnu is not None: | ||
dnu[i] = torch.from_numpy( | ||
np.float64( | ||
vector_of_qps.get(i).results.x[ | ||
dim + n_eq : dim + n_eq + n_in | ||
] | ||
) | ||
) | ||
) | ||
dim_ = 0 | ||
if n_eq > 0: | ||
b_5[i] = torch.from_numpy( | ||
|
@@ -566,16 +614,18 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | |
) | ||
|
||
dps = dx | ||
dGs = ( | ||
bger(dnu.double(), zhats.double()) | ||
+ bger(ctx.nus.double(), dx.double()) | ||
+ bger(P_2_c_s_i.double(), b_6.double()) | ||
) | ||
if G_e: | ||
dGs = dGs.mean(0) | ||
dhs = -dnu | ||
if h_e: | ||
dhs = dhs.mean(0) | ||
dGs = None | ||
if dnu is not None: | ||
dGs = ( | ||
bger(dnu.double(), zhats.double()) | ||
+ bger(ctx.nus.double(), dx.double()) | ||
+ bger(P_2_c_s_i.double(), b_6.double()) | ||
) | ||
if G_e: | ||
dGs = dGs.mean(0) | ||
dhs = -dnu | ||
if h_e: | ||
dhs = dhs.mean(0) | ||
if neq > 0: | ||
dAs = ( | ||
bger(dlam.double(), zhats.double()) | ||
|
@@ -597,16 +647,36 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i): | |
if p_e: | ||
dps = dps.mean(0) | ||
|
||
grads = ( | ||
dQs, | ||
dps, | ||
dAs, | ||
dbs, | ||
dGs[n_in_sol:, :], | ||
-dhs[:n_in_sol], | ||
dhs[n_in_sol:], | ||
) | ||
|
||
if len(ctx.G_size) == 2: | ||
grads = ( | ||
dQs, | ||
dps, | ||
dAs, | ||
dbs, | ||
dGs[n_in_sol:, :], | ||
-dhs[:n_in_sol], | ||
dhs[n_in_sol:], | ||
) | ||
elif len(ctx.G_size) == 3: | ||
grads = ( | ||
dQs, | ||
dps, | ||
dAs, | ||
dbs, | ||
dGs[:, n_in_sol:, :], | ||
-dhs[:, :n_in_sol], | ||
dhs[:, n_in_sol:], | ||
) | ||
else: | ||
grads = ( | ||
dQs, | ||
dps, | ||
dAs, | ||
dbs, | ||
None, | ||
None, | ||
None, | ||
) | ||
return grads | ||
|
||
if structural_feasibility: | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this to ensure at least a 1st-order tensor? Does
torch.empty()
create a zero-dim tensor? Aren't there args for it to set its dimensions?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.empty() does not work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://docs.pytorch.org/docs/stable/generated/torch.empty.html
That means we just need to pass in arguments, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could do torch.empty(0)