210 changes: 140 additions & 70 deletions bindings/python/proxsuite/torch/qplayer.py
@@ -91,7 +91,10 @@ def QPFunction(
class QPFunctionFn(Function):
@staticmethod
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)
if len(Q_.size()) == 3:
nBatch = Q_.size(0)
else:
nBatch = 1
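# Q_ is 3-D (nBatch, n, n) for batched problems; otherwise a single
# problem is assumed and nBatch falls back to 1.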
Q, _ = expandParam(Q_, nBatch, 3)
p, _ = expandParam(p_, nBatch, 2)
G, _ = expandParam(G_, nBatch, 3)
@@ -103,8 +106,13 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
ctx.vector_of_qps = proxsuite.proxqp.dense.BatchQP()

ctx.nBatch = nBatch

_, nineq, nz = G.size()
do_neq = True
if len(G.size()) == 3 or len(G.size()) == 2:
nineq, nz = G.size()[1:]
else:
nineq = 0
nz = Q.size()[-1]
do_neq = False
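# do_neq records whether inequality constraints were given: with an
# empty G the QP is treated as equality-constrained only, and nz is
# read from the last dimension of Q instead.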
neq = A.size(1) if A.nelement() > 0 else 0
assert neq > 0 or nineq > 0
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz
@@ -134,21 +142,20 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
if p[i] is not None:
p__ = p[i].cpu().numpy()
G__ = None
if G[i] is not None:
if do_neq and G[i] is not None:
G__ = G[i].cpu().numpy()
u__ = None
if u[i] is not None:
if do_neq and u[i] is not None:
u__ = u[i].cpu().numpy()
l__ = None
if l[i] is not None:
if do_neq and l[i] is not None:
l__ = l[i].cpu().numpy()
A__ = None
if Ai is not None:
A__ = Ai.cpu().numpy()
b__ = None
if bi is not None:
b__ = bi.cpu().numpy()

qp.init(
H=H__, g=p__, A=A__, b=b__, C=G__, l=l__, u=u__, rho=default_rho
)
@@ -255,8 +262,22 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
class QPFunctionFn_infeas(Function):
@staticmethod
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
n_in, nz = G_.size() # true double-sided inequality size
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)

do_neq = True
if len(G_.size()) == 3:
_, n_in, nz = G_.size()
elif len(G_.size()) == 2:
n_in = G_.size()[-2]
nz = G_.size()[-1]
else:
n_in = Q_.size()[-1]
nz = Q_.size()[-1]
do_neq = False
ctx.G_size = G_.size()
if len(Q_.size()) == 3:
nBatch = Q_.size(0)
else:
nBatch = 1
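# The original rank of G_ is stashed on ctx so that backward() can
# shape the returned gradients accordingly (see the grads selection
# at the end of backward()).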

Q, _ = expandParam(Q_, nBatch, 3)
p, _ = expandParam(p_, nBatch, 2)
@@ -268,32 +289,43 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):

h = torch.cat((-l, u), axis=1) # single-sided inequality
G = torch.cat((-G, G), axis=1) # single-sided inequality
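# The double-sided constraint l <= G x <= u is rewritten as the
# single-sided system [-G; G] x <= [-l; u]; the stacked G therefore
# has nineq = 2 * n_in rows.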

_, nineq, nz = G.size()
neq = A.size(1) if A.nelement() > 0 else 0
if len(G.size()) == 3:
_, nineq, nz = G.size()
else:
nineq = 0
nz = Q.size()[-1]
if len(A.size()) == 3 or len(A.size()) == 2:
neq = A.size(-2) if A.nelement() > 0 else 0
else:
neq = 0
assert neq > 0 or nineq > 0
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz

zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
nus_sol = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # double-sided inequality multiplier
if do_neq:
nus_sol = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # double-sided inequality multiplier
else:
nus_sol = None
lams = (
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
if ctx.neq > 0
else torch.empty()
else torch.tensor([])
Member:
Is this to ensure at least a 1st-order tensor? Does torch.empty() create a zero-dim tensor? Aren't there args for it to set its dimensions?

Author:
torch.empty() does not work.

TypeError: empty() received an invalid combination of arguments - got (), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)

Member:
https://docs.pytorch.org/docs/stable/generated/torch.empty.html
That means we just need to pass in arguments, no?

Author:
We could do torch.empty(0)

>>> torch.empty(0)
tensor([])
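For reference, a quick REPL check of the behaviours discussed in this thread (standard PyTorch; the error text is the one quoted above):

>>> import torch
>>> torch.tensor([]).shape
torch.Size([0])
>>> torch.empty(0).shape
torch.Size([0])
>>> torch.empty()  # no size argument
TypeError: empty() received an invalid combination of arguments - got ()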
)
s_e = (
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
if ctx.neq > 0
else torch.empty()
else torch.tensor([])
)
slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
s_i = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # this one has the size of the original n_in

if do_neq:
s_i = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # this one has the size of the original n_in
else:
s_i = None
vector_of_qps = proxsuite.proxqp.dense.BatchQP()

ctx.cpu = os.cpu_count()
@@ -311,18 +343,21 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
qp.settings.refactor_rho_threshold = default_rho # no refactorization
qp.settings.eps_abs = eps
Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)

H__ = None
if Q[i] is not None:
H__ = Q[i].cpu().numpy()
p__ = None
if p[i] is not None:
p__ = p[i].cpu().numpy()
G__ = None
if G[i] is not None:
if do_neq and G[i] is not None:
G__ = G[i].cpu().numpy()
u__ = None
if h[i] is not None:
if do_neq and h[i] is not None:
u__ = h[i].cpu().numpy()
if not do_neq:
l = None
# l__ = None
# if (l[i] is not None):
# l__ = l[i].cpu().numpy()
@@ -332,7 +367,6 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
b__ = None
if bi is not None:
b__ = bi.cpu().numpy()

qp.init(H=H__, g=p__, A=A__, b=b__, C=G__, l=l, u=u__, rho=default_rho)

if proxqp_parallel:
@@ -348,16 +382,18 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
if nineq > 0:
# we re-convert the solution to a double sided inequality QP
slack = -h[i] + G[i] @ vector_of_qps.get(i).results.x
nus_sol[i] = torch.Tensor(
-vector_of_qps.get(i).results.z[:n_in]
+ vector_of_qps.get(i).results.z[n_in:]
) # de-projecting this one may lose information when using an inexact solution
if do_neq:
nus_sol[i] = torch.Tensor(
-vector_of_qps.get(i).results.z[:n_in]
+ vector_of_qps.get(i).results.z[n_in:]
) # de-projecting this one may lose information when using an inexact solution
nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
slacks[i] = slack.clone().detach()
s_i[i] = torch.tensor(
-vector_of_qps.get(i).results.si[:n_in]
+ vector_of_qps.get(i).results.si[n_in:]
)
if do_neq:
s_i[i] = torch.tensor(
-vector_of_qps.get(i).results.si[:n_in]
+ vector_of_qps.get(i).results.si[n_in:]
)
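# The solver returns one multiplier (and slack) per row of the stacked
# [-G; G]; subtracting the lower-bound block from the upper-bound block
# maps z and si back to one value per original double-sided row.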
if neq > 0:
lams[i] = torch.tensor(vector_of_qps.get(i).results.y)
s_e[i] = torch.tensor(vector_of_qps.get(i).results.se)
@@ -371,7 +407,10 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
@staticmethod
def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
zhats, s_e, Q, p, G, l, u, A, b = ctx.saved_tensors
nBatch = extract_nBatch(Q, p, A, b, G, l, u)
if len(Q.size()) == 3:
nBatch = Q.size(0)
else:
nBatch = 1

Q, Q_e = expandParam(Q, nBatch, 3)
p, p_e = expandParam(p, nBatch, 2)
@@ -414,7 +453,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):

for i in range(nBatch):
Q_i = Q[i].numpy()
C_i = G[i].numpy()
C_i = None
if G is not None and G.numel() != 0:
C_i = G[i].numpy()
A_i = None
if A is not None:
if A.shape[0] != 0:
@@ -436,16 +477,17 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
if neq > 0:
kkt[:dim, dim : dim + n_eq] = A_i.transpose()
kkt[dim : dim + n_eq, :dim] = A_i
kkt[
dim + n_eq + n_in : dim + 2 * n_eq + n_in, dim : dim + n_eq
] = -np.eye(n_eq)
kkt[dim + n_eq + n_in : dim + 2 * n_eq + n_in, dim : dim + n_eq] = (
-np.eye(n_eq)
)
kkt[
dim + n_eq + n_in : dim + 2 * n_eq + n_in,
dim + n_eq + 2 * n_in : 2 * dim + n_eq + 2 * n_in,
] = A_i

kkt[:dim, dim + n_eq : dim + n_eq + n_in] = C_i.transpose()
kkt[dim + n_eq : dim + n_eq + n_in, :dim] = C_i
if n_in > 0:
kkt[:dim, dim + n_eq : dim + n_eq + n_in] = C_i.transpose()
kkt[dim + n_eq : dim + n_eq + n_in, :dim] = C_i

D_1_c = np.eye(n_in) # represents [s_i]_- + z_i < 0
D_1_c[P_1, P_1] = 0.0
@@ -485,9 +527,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[
i
][~active_set]
rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][
active_set
] = -dl_dnus[i][active_set]
rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][active_set] = (
-dl_dnus[i][active_set]
)
Comment on lines +530 to +532

Member:
Is this a file formatting change? Did it come from pre-commit and ruff?

Author (@oomcth, Oct 7, 2025):
It's a formatting change. It comes from ruff on my computer.

if dl_ds_e is not None:
if dl_ds_e.shape[0] != 0:
rhs[dim + n_eq + n_in : dim + 2 * n_eq + n_in] = -dl_ds_e[i]
@@ -515,9 +557,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):

qp.settings.primal_infeasibility_solving = True
qp.settings.eps_abs = eps_backward
qp.settings.max_iter = 10
qp.settings.default_rho = 1.0e-3
qp.settings.refactor_rho_threshold = 1.0e-3
qp.settings.max_iter = 1000
qp.settings.default_rho = 5.0e-5
qp.settings.refactor_rho_threshold = 5.0e-5
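# Note: compared with the previous settings (max_iter=10, rho=1e-3),
# the backward KKT solve now uses a much smaller rho and up to 1000
# iterations, which should favour accuracy of the computed gradients
# over speed.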
qp.init(
H,
g,
@@ -542,13 +584,19 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
)
if n_eq > 0:
dlam[i] = torch.from_numpy(
np.float64(vector_of_qps.get(i).results.x[dim : dim + n_eq])
vector_of_qps.get(i)
.results.x[dim : dim + n_eq]
.astype(np.float64)
)
dnu[i] = torch.from_numpy(
np.float64(
vector_of_qps.get(i).results.x[dim + n_eq : dim + n_eq + n_in]

if dnu is not None:
dnu[i] = torch.from_numpy(
np.float64(
vector_of_qps.get(i).results.x[
dim + n_eq : dim + n_eq + n_in
]
)
)
)
dim_ = 0
if n_eq > 0:
b_5[i] = torch.from_numpy(
@@ -566,16 +614,18 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
)

dps = dx
dGs = (
bger(dnu.double(), zhats.double())
+ bger(ctx.nus.double(), dx.double())
+ bger(P_2_c_s_i.double(), b_6.double())
)
if G_e:
dGs = dGs.mean(0)
dhs = -dnu
if h_e:
dhs = dhs.mean(0)
dGs = None
if dnu is not None:
dGs = (
bger(dnu.double(), zhats.double())
+ bger(ctx.nus.double(), dx.double())
+ bger(P_2_c_s_i.double(), b_6.double())
)
if G_e:
dGs = dGs.mean(0)
dhs = -dnu
if h_e:
dhs = dhs.mean(0)
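# When no inequality constraints were given, dnu is expected to be
# None, so dGs and dhs stay None and the grads tuple below returns
# None for G, l and u.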
if neq > 0:
dAs = (
bger(dlam.double(), zhats.double())
Expand All @@ -597,16 +647,36 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
if p_e:
dps = dps.mean(0)

grads = (
dQs,
dps,
dAs,
dbs,
dGs[n_in_sol:, :],
-dhs[:n_in_sol],
dhs[n_in_sol:],
)

if len(ctx.G_size) == 2:
grads = (
dQs,
dps,
dAs,
dbs,
dGs[n_in_sol:, :],
-dhs[:n_in_sol],
dhs[n_in_sol:],
)
elif len(ctx.G_size) == 3:
grads = (
dQs,
dps,
dAs,
dbs,
dGs[:, n_in_sol:, :],
-dhs[:, :n_in_sol],
dhs[:, n_in_sol:],
)
else:
grads = (
dQs,
dps,
dAs,
dbs,
None,
None,
None,
)
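# Since h = cat((-l, u)), the first n_in_sol rows of dhs correspond to
# -l and the remaining rows to u; the slicing above maps the
# single-sided gradients back onto the double-sided (G, l, u)
# parameters, keeping the batch dimension when G was 3-D.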
return grads

if structural_feasibility: