305 changes: 198 additions & 107 deletions test/llm/test_objectives.py
@@ -14,9 +14,9 @@
from tensordict import lazy_stack, TensorDict
from torchrl.data import History, LazyStackStorage, ReplayBuffer
from torchrl.envs.llm.transforms.kl import RetrieveLogProb
from torchrl.modules.llm import Text, TransformersWrapper, vLLMWrapper
from torchrl.modules.llm.policies.common import ChatHistory, Masks, Tokens
from torchrl.objectives.llm.grpo import MCAdvantage
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
from torchrl.objectives.llm.sft import SFTLoss

_has_transformers = importlib.util.find_spec("transformers") is not None
@@ -53,7 +53,7 @@ def make_silly_trajectory(n_steps=None):
rewards = [torch.randn(n_tokens, 1)]
prompt = np.random.choice(prompts)
td = TensorDict(
text=Text(prompt=prompt),
query=prompt, # MCAdvantage expects "query" key, not "text"
next=TensorDict(
reward=rewards, done=torch.zeros(1, dtype=torch.bool)
),
@@ -83,8 +83,158 @@ def make_silly_trajectory(n_steps=None):
assert "advantage" in s.keys()


def test_grpo():
...
# Mock infrastructure moved to conftest.py
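
The mock_transformer_model fixture used by TestLosses.test_grpo below now lives in conftest.py and is not shown in this diff. As a rough illustration only (the fixture name and factory signature are taken from the test; everything else here is an assumption), such a fixture could build a tiny, randomly initialized GPT-2 style model so the loss computation runs quickly on CPU:

import pytest
import torch
from transformers import GPT2Config, GPT2LMHeadModel


@pytest.fixture
def mock_transformer_model():
    """Factory fixture returning a small causal LM suitable for CPU-only tests."""

    def _make(vocab_size: int, device: torch.device | str = "cpu"):
        # Deliberately tiny config: forward passes stay cheap while the model
        # still emits logits over the requested vocabulary size.
        config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=128,
            n_embd=32,
            n_layer=2,
            n_head=2,
        )
        return GPT2LMHeadModel(config).to(device).eval()

    return _make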


def _mock_data_grpo(
vocab_size: int, device: torch.device | str = "cpu"
) -> TensorDict:
from transformers import AutoTokenizer

device = torch.device(device)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
prompt = History(
role=["system", "user"],
content=["You are a useful assistant.", "What is 2+2?"],
batch_size=(2,),
device=device,
)
response = History(
role=["assistant"],
content=["2 + 2 = 4."],
batch_size=(1,),
device=device,
)
full_history = prompt.extend(response, inplace=False)
history = ChatHistory(
prompt=prompt,
response=response,
full=full_history,
device=device,
)
batch_size = 1

# Expand history to match batch size before getting tokens
history = history.expand((batch_size,))
next_history = ChatHistory(
prompt=full_history,
device=device,
)
next_history = next_history.expand((batch_size,))

# Now get tokens from the expanded history objects
tokens_full = history.to_tokens(tokenizer)
next_tokens = next_history.to_tokens(tokenizer)

# Get the actual sequence length from the tokens
# tokens_full has structure with "full" key containing the actual tokens
# We need to get the padded version to know the actual length
tokens_input_ids = tokens_full.get(
"full", as_padded_tensor=True, padding_side="left", padding_value=0
)
seq_len = tokens_input_ids.shape[-1]

# Create tensors with proper shapes
reward = torch.randn(batch_size, seq_len, 1, device=device)
done = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool, device=device)
advantage = torch.randn(batch_size, seq_len, 1, device=device)
log_probs = torch.randn_like(tokens_full, dtype=torch.float32, device=device)

# Create attention mask (all ones for non-padded tokens)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)

# Import Masks to create proper mask structure
from tensordict import MetaData
from torchrl.modules.llm.policies.common import Masks

masks = Masks(
all_attention_mask=attention_mask,
all_assistant_mask=None, # Will be computed by the wrapper
padded=MetaData(True),
device=device,
)

data = TensorDict(
{
"advantage": advantage,
"history": history,
"tokens": tokens_full % vocab_size,
"masks": masks,
"next": {
"history": next_history,
"tokens": next_tokens % vocab_size,
"reward": reward,
"done": done,
},
"log_probs": log_probs,
},
batch_size=(batch_size,),
)
return data


class TestLosses:
def test_grpo(self, mock_transformer_model):
"""Test GRPO loss computation with mock models."""
vocab_size = 1024
device = torch.device("cpu")

# Create mock model and wrap it
model = mock_transformer_model(vocab_size=vocab_size, device=device)
actor_network = TransformersWrapper(
model,
generate=False,
pad_output=True,
input_mode="history",
)

# Create loss module
loss_fn = GRPOLoss(actor_network)

# Create fake data
data = _mock_data_grpo(vocab_size=vocab_size, device=device)

# Compute loss
loss_vals = loss_fn(data)

# Assertions: Check output type and structure
from torchrl.objectives.llm.grpo import GRPOLossOutput

assert isinstance(
loss_vals, GRPOLossOutput
), f"Expected GRPOLossOutput, got {type(loss_vals)}"

# Check that all expected keys are present
assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
assert hasattr(loss_vals, "ESS"), "Missing ESS"
assert hasattr(loss_vals, "entropy"), "Missing entropy"
assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"

# Check tensor shapes (all losses should be scalars after reduction)
assert (
loss_vals.loss_objective.shape == ()
), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
assert (
loss_vals.clip_fraction.shape == ()
), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
assert (
loss_vals.kl_approx.shape == ()
), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
assert (
loss_vals.ESS.shape == ()
), f"ESS should be scalar, got {loss_vals.ESS.shape}"

# Check that losses are finite
assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
assert torch.isfinite(loss_vals.ESS), "ESS is not finite"

# Check that clip_fraction is in valid range [0, 1]
assert (
0 <= loss_vals.clip_fraction <= 1
), f"clip_fraction out of range: {loss_vals.clip_fraction}"


class TestSFT:
@@ -203,7 +353,7 @@ def test_sft(
assistant_only=True,
tokenizer_kwargs={"chat_template_name": "qwen"},
tokenizer=tokenizer,
log_probs_key=("ref_log_prob", "full"),
log_probs_full_key=("ref_log_probs", "full"),
)
with torch.no_grad():
# Compute ref log-probs
@@ -247,7 +397,7 @@ def test_sft_assistant_only(self, data):
assistant_only=True,
tokenizer_kwargs={"chat_template_name": "qwen"},
tokenizer=tokenizer,
log_probs_key=("ref_log_prob", "full"),
log_probs_full_key=("ref_log_probs", "full"),
)
td = transform(data)
assert td is data
@@ -262,10 +412,12 @@ def test_sft_assistant_only(self, data):
loss(td)


@pytest.mark.slow
@pytest.mark.integration
class TestGRPOLossIntegration:
"""Test GRPOLoss integration with the new distribution methods."""
"""Integration tests for GRPOLoss with real models (vLLM + transformers)."""

@pytest.fixture(scope="module")
@pytest.fixture(scope="class")
def transformers_instance(self):
"""Create transformers model and tokenizer for testing."""
if not _has_transformers:
Expand All @@ -277,7 +429,7 @@ def transformers_instance(self):
tokenizer.pad_token = tokenizer.eos_token
return model, tokenizer

@pytest.fixture(scope="module")
@pytest.fixture(scope="class")
def vllm_instance(self):
"""Create vLLM model and tokenizer for testing."""
if not _has_vllm:
@@ -297,102 +449,52 @@ def vllm_instance(self):
except Exception as e:
pytest.skip(f"Failed to load vLLM model: {e}")

@pytest.fixture(scope="module")
def sample_tokens(self, vllm_instance):
"""Create sample tokens for testing."""
model, tokenizer = vllm_instance
text = [
"Are you happy? Say yes or no.",
"Explain the difference between a cat and a dog. Be very detailed.",
]
tokenized = tokenizer(
text, return_tensors="pt", padding=True, padding_side="left"
)
return tokenized["input_ids"], tokenized["attention_mask"]

@pytest.fixture(scope="module")
def sample_text(self):
"""Create sample text for testing."""
return [
"Are you happy? Say yes or no.",
"Explain the difference between a cat and a dog. Be very detailed.",
]

@pytest.fixture(scope="module")
def sample_history(self):
"""Create sample conversation history for testing."""
chats = [
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Are you happy? Say yes or no."},
],
[
{
"role": "system",
"content": "You are a very helpful assistant, but more handsome.",
},
{
"role": "user",
"content": "Explain the difference between a cat and a dog. Be very detailed.",
},
],
]
return History.from_chats(chats)

@pytest.fixture(scope="module")
def sample_history_assistant(self):
"""Create sample conversation history for testing."""
chats = [
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Are you happy? Say yes or no."},
{"role": "assistant", "content": "Yes."},
],
[
{
"role": "system",
"content": "You are a very helpful assistant, but more handsome.",
},
{
"role": "user",
"content": "Explain the difference between a cat and a dog. Be very detailed.",
},
{
"role": "assistant",
"content": "A cat is a small animal that meows, while a dog is a larger animal that barks.",
},
],
]
return History.from_chats(chats)

@pytest.mark.skipif(not _has_vllm, reason="vllm not available")
@pytest.mark.parametrize("masking_strategy", ["sft", "rlhf"])
def test_grpo_loss_with_transformers(
def test_grpo_loss_with_real_models(
self,
vllm_instance,
transformers_instance,
sample_history,
sample_tokens,
masking_strategy,
):
"""Test GRPOLoss with vLLM wrapper and different masking strategies."""
"""Test GRPOLoss with vLLM generation and transformers loss computation."""
from torchrl.objectives.llm.grpo import GRPOLoss

model, tokenizer = transformers_instance
vllm_model, vllm_tokenizer = vllm_instance

# Use tokens input mode for SFT, history for RLHF/generic
# Create sample input based on masking strategy
if masking_strategy == "sft":
input_mode = "tokens"
input_ids, attention_mask = sample_tokens
# Use tokens input mode for SFT
text = [
"Are you happy? Say yes or no.",
"What is 2+2?",
]
tokenized = tokenizer(
text, return_tensors="pt", padding=True, padding_side="left"
)
input_data = {
"tokens": Tokens(prompt=input_ids),
"masks": Masks(all_attention_mask=attention_mask),
"tokens": Tokens(prompt=tokenized["input_ids"]),
"masks": Masks(all_attention_mask=tokenized["attention_mask"]),
}
input_mode = "tokens"
else:
input_mode = "history"
# Use history input mode for RLHF
chats = [
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Are you happy? Say yes or no."},
],
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"},
],
]
sample_history = History.from_chats(chats)
input_data = {"history": ChatHistory(prompt=sample_history)}
input_mode = "history"

# Generate responses with vLLM
wrapper_gen = vLLMWrapper(
vllm_model,
tokenizer=vllm_tokenizer,
Expand All @@ -403,12 +505,11 @@ def test_grpo_loss_with_transformers(
generate_kwargs={"max_tokens": 10},
)

# Create test data with advantage and correct batch size
td = TensorDict(input_data, batch_size=(2,)).to_lazystack(0)
td = wrapper_gen(td)
# Use a shape that can be broadcast
td["advantage"] = torch.randn(2, 1, 1)

# Compute loss with transformers
wrapper = TransformersWrapper(
model,
tokenizer=tokenizer,
Expand All @@ -418,23 +519,13 @@ def test_grpo_loss_with_transformers(
pad_output=True,
)

# Create GRPOLoss with specified masking strategy
loss_fn = GRPOLoss(
actor_network=wrapper,
masking_strategy=masking_strategy,
)
loss_fn = GRPOLoss(actor_network=wrapper, masking_strategy=masking_strategy)

# This should work without shape mismatch errors
try:
result = loss_fn(td)
assert result is not None
except ValueError as e:
if "Shape mismatch" in str(e):
# This is expected if the advantage shape doesn't match the log-prob shape
# due to different masking strategies
assert masking_strategy in str(e)
else:
raise
# Should successfully compute loss
result = loss_fn(td)
assert result is not None
assert hasattr(result, "loss_objective")
assert torch.isfinite(result.loss_objective)


if __name__ == "__main__":