Skip to content

Conversation

nsbg
Copy link

@nsbg nsbg commented Aug 2, 2025

cc @BenjaminBossan

I was delayed in updating the code because I was focusing on company work, but now I'm planning to resume the project in earnest. If I have any questions about implementing the code, may I continue to ask you?

I apologize for opening a new pull request, as the previous one was closed 🥲 Thank you for your understanding.

@nsbg nsbg marked this pull request as draft August 2, 2025 05:45
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for resuming your work on KaSA.

Implementation-wise, we need to take a different approach. Right now, KaSA is just added to the normal LoRA code, but we only want to activate it if the user opts in. Therefore, it should be implemented in a separate class, something like KasaVariant, in peft/tuners/lora/variants.py. Please check how DoRA is implemented and use a similar approach, as I have detailed in my previous comment. If anything is unclear, feel free to ask.

Copy link

github-actions bot commented Sep 1, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Copy link
Member

gentle ping @nsbg

@nsbg
Copy link
Author

nsbg commented Sep 2, 2025

Thank you for your alert!

I spent some time looking over the KaSA paper and code to get ready for more serious work, but it does seem pretty difficult 🥲 My goal is to upload code that's ready for review before the end of September, so I'm going to try even harder.

Right now, I'm stuck at the 'Extend LoRA variant resolution' stage you mentioned. Honestly, this seems like the most important part, but it's hard for me to figure out where to start—specifically, which file and class I should work on first. Could you help me with this?

@BenjaminBossan
Copy link
Member

That's great to see, thanks for picking this back up.

Right now, I'm stuck at the 'Extend LoRA variant resolution' stage you mentioned. Honestly, this seems like the most important part, but it's hard for me to figure out where to start—specifically, which file and class I should work on first. Could you help me with this?

You're already on the right track, you added KasaLinearVariant, which is the most important step. There are definitely some changes required there, as there is some code that is only relevant for DoRA and can be removed for KaSA. But we can leave that as is for now.

Next about resolving the variants. As a first step, let's revert the changes you made to lora/layer.py and start fresh. We don't need a self.use_kasa attribute, we only have self.use_dora for backwards compatibility, as we didn't have LoRA variants when we first implemented DoRA.

Then let's look at these lines in lora.Linear:

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None
from .variants import DoraLinearVariant
return DoraLinearVariant()

Here we need to extend the functionality to add KaSA. The updated method could be something like:

    def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
        if use_dora and use_kasa:
            raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.")

        variant = None
        if use_dora:
            from .variants import DoraLinearVariant

            variant = DoraLinearVariant()
        elif use_kasa:
            ...

        return variant

Does that make sense? Similarly, we'd have to update the resolve_lora_variant methods of other LoRA layers, depending on whether they work with KaSA or not (I'm not sure if KaSA works with Conv2d etc.).

I would suggest that you work on this as a next step, then we'll see what else needs to be done.

@nsbg
Copy link
Author

nsbg commented Sep 4, 2025

wow I really appreciate your sincere feedback. I'll read your advice carefully and then move forward 🤗

@nsbg
Copy link
Author

nsbg commented Sep 8, 2025

@BenjaminBossan I modified the code in the files below based on what you explained. Please give me feedback if there are parts that still need fixing, and then we can discuss the next steps.

1. variants.py

  • Completed updates to methods in the KasaLinearVariants class

2. layer.py

  • In the LoraLayer class, added self.use_kasa[adapter_name] = use_kasa inside the update_layer method

  • In the Linear class, added KaSA handling logic inside the get_delta_weight method

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for integrating my feedback. I gave this another review and noted the next few changes that are necessary. Please check my comments.

Apart from this, the branch is now encountering merge conflicts. Could you please bring your fork up-to-date with the remote and then merge with, or rebase on, the latest main branch from PEFT? If you have questions on how to resolve the merge conflicts, don't hesitate to ask.

Furthermore, please always run make style on your changes before pushing to make our linter happy.

More of a note for myself: Since KaSA updates the base weights of the model, we will have to take extra care to ensure that it works correctly when saving and loading the adapter.

"""
return None
if use_dora and use_kasa:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's undo the changes in this method body and return None. Instead, since this KaSA layer is implemented for Linear only, add the logic to lora.Linear.resolve_lora_variant instead.

Also, we should update the resolve_lora_variant methods of the other layer types like lora.Embedding.resolve_lora_variant to accept the use_kasa argument but raise an error if it's True. Otherwise, users may add it to non-supported layers and not notice that it doesn't actually do anything there.

Comment on lines 236 to 247
############ kasa #############
self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)

weight = self.get_base_layer().weight
dtype = weight.dtype
svd_rank = self.in_features - r
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)

#########################
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of this can be removed, since it's part of KasaLinearVariant.init, right?

# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# SVD
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a reference here, so that we know the origin:
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132
        
# SVD

I put it in here, how is it?

Comment on lines +335 to +348
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = module.get_delta_weight(active_adapter)
return orig_weight + delta_weight

@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
delta_weight = module.get_delta_weight(active_adapter)
orig_weight.data += delta_weight

@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = module.get_delta_weight(active_adapter)
return orig_weight - delta_weight
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KaSA should have an influence on the merged weights, should it not?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although this PR is closed, it seems I've incorporated everything else except for this comment (of course, you'd have to look at the code). Could you explain this question in more detail?

x = dropout(x)

# KaSA calculation
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, let's add a reference:
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# KaSA calculation
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
return result + lora_output

I inserted this near where the actual calculation logic begins, rather than just in an empty space. I think this is a bit better.

@nsbg
Copy link
Author

nsbg commented Sep 16, 2025

@BenjaminBossan oh I didn't mean to close the branch, but it seems to have closed while I was merging with the main branch. I guess I'll have to open a new PR, right? 😰

+) when I tried to sync with the main branch, I ended up discarding all my commits, so did that cause it to close?

@BenjaminBossan
Copy link
Member

oh I didn't mean to close the branch, but it seems to have closed while I was merging with the main branch. I guess I'll have to open a new PR, right? 😰

+) when I tried to sync with the main branch, I ended up discarding all my commits, so did that cause it to close?

I don't know what happened, but I could re-open the PR and there are some changes visible. Can you double check that everything looks as expected? If for some reason it's not what it's expected, you can create a new PR and push your local branch.

@nsbg
Copy link
Author

nsbg commented Sep 17, 2025

I usually handle merges in the terminal, and I suspect the pull request was closed because I accidentally wiped the commit history while using the 'Sync fork' feature on GitHub. I'll be more careful in the future. Thanks for reopening it.

I'll review the changes and open a new PR if needed. Sorry to keep bothering you with this.

@BenjaminBossan
Copy link
Member

I'll review the changes and open a new PR if needed. Sorry to keep bothering you with this.

No worries. If the diff on this PR looks good, let me know and I'll do a review. Only open a new PR if for some reason, the code here does not correspond to what it should be.

@nsbg
Copy link
Author

nsbg commented Sep 20, 2025

@BenjaminBossan I checked layer.py/variants.py and KasaLinearVariants class in variants.py was removed. I added it again and I updated file based on your minor feedback, so I think we can discuss in this PR continually.

BTW I ran make style command and got this error.

make style
ruff check --fix src tests examples docs scripts docker
process_begin: CreateProcess(NULL, ruff check --fix src tests examples docs scripts docker, ...) failed.

I ran pip install -e .[test] command in https://huggingface.co/docs/peft/install#source, but I got same error. Do I just run that command directly without needing to set up a virtual environment?

@nsbg
Copy link
Author

nsbg commented Sep 20, 2025

image

maybe make style related error was fixed. After applying this command, quite a few files have changed. Is it okay to just push them? Also, what exactly does make style do?

@BenjaminBossan
Copy link
Member

maybe make style related error was fixed. After applying this command, quite a few files have changed. Is it okay to just push them? Also, what exactly does make style do?

No, let's not push any changes to unrelated files. If make style changes unrelated files, it's often to one of these reasons:

  1. Wrong ruff version: check that v0.12.12 is installed in your virtual environment.
  2. Not picking up the config: There are some settings for ruff in the pyproject.toml file, ensure that it's there when you run make style.

@nsbg
Copy link
Author

nsbg commented Sep 27, 2025

@BenjaminBossan

Also, we should update the resolve_lora_variant methods of the other layer types like lora.Embedding.resolve_lora_variant to accept the use_kasa argument but raise an error if it's True. Otherwise, users may add it to non-supported layers and not notice that it doesn't actually do anything there.

I referred to your explanation and added the use_kasa parameter to the resolve_lora_variant method in the classes below.

  • Linear (line 702)
  • Embedding (line 934)
  • _ConvNd (line 1263)
  • Conv2d (line 1507)
  • Conv1d (line 1524)
  • Conv3d (line 1541)

The logic for raising errors in each layer hasn’t been applied yet, but I committed first to check whether adding the parameter in this way matches what you meant. Excluding the Linear class, it seems like an error should be raised when use_kasa is true in the other classes. However, I might be mistaken, so please feel free to give me feedback anytime. Also, I noticed there’s no part that calls KasaLinearVariant—should this be called inside the linear class? I’m a bit confused about this part.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates, we're making good progress here.

The logic for raising errors in each layer hasn’t been applied yet, but I committed first to check whether adding the parameter in this way matches what you meant. Excluding the Linear class, it seems like an error should be raised when use_kasa is true in the other classes.

Nice, this looks correct, please raise the error in the unsupported layers as indicated.

Also, I noticed there’s no part that calls KasaLinearVariant—should this be called inside the linear class? I’m a bit confused about this part.

Don't worry, it is being called. E.g. in lora.Linear.forward we have this code:

if active_adapter not in self.lora_variant: # vanilla LoRA
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
result = self.lora_variant[active_adapter].forward(
self,
active_adapter=active_adapter,
x=x,
result=result,
**variant_kwargs,
**kwargs,
)

So if the KaSA variant is found, KasaLinearVariant.forward will be used here. Same for the other methods.

U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
module.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a new @staticmethod called _get_delta_weight here. This method should implement the KaSA delta weight logic from above:

            diag = torch.diag(module.lora_diag[adapter])
            output_tensor = transpose(weight_B @ diag @ weight_A, module.fan_in_fan_out) * module.scaling[adapter]

Then, the merge_safe, merge_unsafe, and unmerge methods below can call KasaLinearVariant._get_delta_weight(...).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants