diff --git a/tools/fuzzing/docs/recursion_control.md b/tools/fuzzing/docs/recursion_control.md new file mode 100644 index 0000000..bd6a17f --- /dev/null +++ b/tools/fuzzing/docs/recursion_control.md @@ -0,0 +1,244 @@ +# Recursion Control in Grammar-Aware Fuzzing + +## Overview + +This document describes our dependency graph-based approach to handle recursion in ANTLR 4 grammars for the fuzzing system. The strategy ensures valid output generation while preventing infinite loops and stack overflows. + +## Our Strategy: Dependency Graph with Terminal Reachability + +### Core Approach + +1. **Build dependency graph** during grammar parsing +2. **Analyze terminal reachability** for each rule +3. **Force terminal alternatives** when hitting recursion/depth limits + +### Key Principles + +- **Rule = Graph Node**: Each grammar rule becomes a node +- **Reference = Graph Edge**: `a -> b` when rule `a` references rule `b` +- **Terminal Reachability**: Every rule must have at least one path to terminal nodes +- **Alternative Classification**: Mark which alternatives can terminate without recursion + +## Graph Structure + +### Node Definition + +```go +type GraphNode struct { + RuleName string // Rule name (e.g., "selectStmt", "expr") + HasTerminalAlternatives bool // Can reach terminal without recursion + Alternatives []Alternative // All alternatives for this rule + TerminalAlternativeIndex []int // Indices of alternatives that terminate +} + +type DependencyGraph struct { + Nodes map[string]*GraphNode +} +``` + +### Edge Types + +- **Self-Reference**: `expr -> expr` (direct recursion) +- **Cross-Reference**: `selectStmt -> whereClause` (potential indirect recursion) +- **Terminal Reference**: `expr -> NUMBER` (terminates) + +## Implementation Algorithm + +### Step 1: Build Graph During Parsing + +```go +func BuildDependencyGraph(grammar *ParsedGrammar) *DependencyGraph { + graph := &DependencyGraph{Nodes: make(map[string]*GraphNode)} + + // Create nodes for all rules + for ruleName, rule := range grammar.GetAllRules() { + node := &GraphNode{ + RuleName: ruleName, + Alternatives: rule.Alternatives, + } + graph.Nodes[ruleName] = node + } + + // Analyze each rule for terminal reachability + analyzeTerminalReachability(graph) + + return graph +} +``` + +### Step 2: Terminal Reachability Analysis + +```go +func analyzeTerminalReachability(graph *DependencyGraph) { + // Phase 1: Mark lexer rules as terminal + for _, node := range graph.Nodes { + if isLexerRule(node.RuleName) { + node.HasTerminalAlternatives = true + // All lexer alternatives are terminal + for i := range node.Alternatives { + node.TerminalAlternativeIndex = append(node.TerminalAlternativeIndex, i) + } + } + } + + // Phase 2: Propagate terminal reachability + changed := true + for changed { + changed = false + for _, node := range graph.Nodes { + if node.HasTerminalAlternatives { + continue + } + + // Check each alternative + for altIndex, alt := range node.Alternatives { + if canAlternativeTerminate(alt, graph) { + if !node.HasTerminalAlternatives { + node.HasTerminalAlternatives = true + changed = true + } + node.TerminalAlternativeIndex = append(node.TerminalAlternativeIndex, altIndex) + } + } + } + } +} + +func canAlternativeTerminate(alt Alternative, graph *DependencyGraph) bool { + for _, element := range alt.Elements { + if element.IsRule() { + referencedNode := graph.Nodes[element.RuleName] + if referencedNode == nil || !referencedNode.HasTerminalAlternatives { + return false + } + } + // Literals and lexer rules are always terminal + } + return true +} +``` + +### Step 3: Generation with Terminal Forcing + +```go +func (g *Generator) generateFromRule(ruleName string, activeRules map[string]bool, depth int) string { + node := g.dependencyGraph.Nodes[ruleName] + + // Grammar validation: ensure rule can terminate + if !node.HasTerminalAlternatives { + return "", fmt.Errorf("unsupported grammar: rule '%s' has no terminal alternatives", ruleName) + } + + // Force terminal alternatives when hitting limits + if activeRules[ruleName] || depth >= g.config.MaxDepth { + return g.forceTerminalGeneration(node) + } + + // Normal generation + activeRules[ruleName] = true + defer delete(activeRules, ruleName) + + altIndex := g.random.Intn(len(node.Alternatives)) + return g.generateFromAlternative(node.Alternatives[altIndex], activeRules, depth+1) +} + +func (g *Generator) forceTerminalGeneration(node *GraphNode) string { + // Choose randomly from terminal alternatives only + terminalIndex := g.random.Intn(len(node.TerminalAlternativeIndex)) + altIndex := node.TerminalAlternativeIndex[terminalIndex] + + // Generate with fresh context to avoid recursion + return g.generateFromAlternative(node.Alternatives[altIndex], make(map[string]bool), 0) +} +``` + +## Special Cases + +### Empty Alternatives (ε-transitions) + +```antlr +optionalClause: whereClause | /* empty */ ; +``` + +**Handling**: Create implicit ε-node for empty alternatives: +```go +// Empty alternatives are always terminal +if len(alt.Elements) == 0 { + node.TerminalAlternativeIndex = append(node.TerminalAlternativeIndex, altIndex) +} +``` + +### Quantified Elements + +```antlr +stmt: 'BEGIN' stmt* 'END'; // stmt* can be 0 occurrences +``` + +**Handling**: Quantifiers `*` and `?` create implicit terminal paths: +```go +func canElementTerminate(element Element, graph *DependencyGraph) bool { + if element.Quantifier == ZERO_MORE || element.Quantifier == OPTIONAL_Q { + return true // Can generate 0 occurrences + } + // Check if referenced rule can terminate + return graph.Nodes[element.RuleName].HasTerminalAlternatives +} +``` + +### Grammar Validation + +**Unsupported Grammars**: Rules with no terminal alternatives: +```antlr +// This will cause validation error +expr: '(' expr ')'; // No base case! +``` + +**Error Handling**: +```go +func ValidateGrammar(graph *DependencyGraph) error { + for ruleName, node := range graph.Nodes { + if !node.HasTerminalAlternatives { + return fmt.Errorf("grammar error: rule '%s' has no terminal alternatives", ruleName) + } + } + return nil +} +``` + +## Example: PostgreSQL Expression Rule + +```antlr +a_expr: a_expr '+' a_expr // Alternative 0: NON-TERMINAL (recursive) + | a_expr '*' a_expr // Alternative 1: NON-TERMINAL (recursive) + | '(' a_expr ')' // Alternative 2: NON-TERMINAL (depends on a_expr) + | c_expr // Alternative 3: TERMINAL (if c_expr terminates) + ; + +c_expr: columnref // Alternative 0: TERMINAL (lexer rule) + | '(' a_expr ')' // Alternative 1: NON-TERMINAL (recursive) + ; + +columnref: IDENTIFIER; // TERMINAL (lexer rule) +``` + +**Analysis Result**: +```go +a_expr.HasTerminalAlternatives = true +a_expr.TerminalAlternativeIndex = [3] // Only c_expr alternative + +c_expr.HasTerminalAlternatives = true +c_expr.TerminalAlternativeIndex = [0] // Only columnref alternative +``` + +**Generation Behavior**: +- **Normal case**: Choose any alternative randomly +- **Recursion/MaxDepth**: Force choice from `TerminalAlternativeIndex` only +- **Result**: Always generates valid expressions without stack overflow + +## Benefits + +1. **No Stack Overflow**: Guaranteed termination via terminal forcing +2. **Valid Output**: No placeholders, always generates parseable content +3. **Grammar Coverage**: Supports all ANTLR 4 constructs including quantifiers +4. **Early Validation**: Detects unsupported grammars during initialization +5. **Efficient**: O(1) lookup for terminal alternatives during generation \ No newline at end of file diff --git a/tools/fuzzing/internal/generator/generator.go b/tools/fuzzing/internal/generator/generator.go index 9eb2ea6..2fbff97 100644 --- a/tools/fuzzing/internal/generator/generator.go +++ b/tools/fuzzing/internal/generator/generator.go @@ -12,9 +12,10 @@ import ( // Generator handles the fuzzing logic type Generator struct { - config *config.Config - random *rand.Rand - grammar *grammar.ParsedGrammar + config *config.Config + random *rand.Rand + grammar *grammar.ParsedGrammar + dependencyGraph *grammar.DependencyGraph } // WorkItem represents a unit of work in the generation stack @@ -27,35 +28,44 @@ type WorkItem struct { // New creates a new generator with the given configuration func New(cfg *config.Config) *Generator { return &Generator{ - config: cfg, - random: rand.New(rand.NewSource(cfg.Seed)), - grammar: nil, + config: cfg, + random: rand.New(rand.NewSource(cfg.Seed)), + grammar: nil, + dependencyGraph: nil, } } // Generate produces the specified number of queries func (g *Generator) Generate() error { fmt.Println("Initializing grammar parser...") - + // Parse and merge all grammar files into a single grammar var err error g.grammar, err = grammar.ParseAndMergeGrammarFiles(g.config.GrammarFiles) if err != nil { return errors.Wrap(err, "failed to parse and merge grammar files") } - + fmt.Printf("Parsed and merged %d grammar files into single grammar\n", len(g.config.GrammarFiles)) + // Set up dependency graph + g.dependencyGraph = g.grammar.GetDependencyGraph() + + // Validate grammar has terminal alternatives (non-fatal warning) + if err := g.grammar.ValidateGrammar(); err != nil { + fmt.Printf("Grammar validation warning: %v\n", err) + } + // Validate start rule exists if g.grammar.GetRule(g.config.StartRule) == nil { return errors.Errorf("start rule '%s' not found in merged grammar", g.config.StartRule) } fmt.Printf("Generating %d queries from rule '%s'...\n", g.config.Count, g.config.StartRule) - + // Generate queries for i := 0; i < g.config.Count; i++ { - query := g.generateQuery(i + 1) + query := g.generateQuery() fmt.Printf("Query %d: %s\n", i+1, query) } @@ -67,40 +77,65 @@ func (g *Generator) getRule(ruleName string) *grammar.Rule { return g.grammar.GetRule(ruleName) } - // generateQuery creates a single query using grammar rules -func (g *Generator) generateQuery(index int) string { - // Start generation from the specified start rule with no recursion limit for now - result := g.generateFromRule(g.config.StartRule, 0) - return result +func (g *Generator) generateQuery() string { + // Start with no SCC context and 0 recursion depth + return g.generateFromRuleWithSCC(g.config.StartRule, grammar.NoSCC, 0) } -// generateFromRule generates text from a grammar rule -func (g *Generator) generateFromRule(ruleName string, currentDepth int) string { - // Check depth limit to prevent infinite recursion - if currentDepth >= g.config.MaxDepth { - return fmt.Sprintf("<%s_MAX_DEPTH>", ruleName) - } +// generateFromRule is a wrapper for backward compatibility +func (g *Generator) generateFromRule(ruleName string, depth int) string { + // For backward compatibility, treat depth as recursion depth + return g.generateFromRuleWithSCC(ruleName, grammar.NoSCC, depth) +} - // Get the rule +// generateFromRuleWithSCC generates text from a grammar rule tracking SCC-based recursion +func (g *Generator) generateFromRuleWithSCC(ruleName string, currentSCCID int, recursionDepth int) string { + // Get the rule and its SCC info rule := g.getRule(ruleName) if rule == nil { - // If rule not found, return placeholder return fmt.Sprintf("<%s>", ruleName) } - // Select a random alternative - if len(rule.Alternatives) == 0 { + node := g.dependencyGraph.GetNode(ruleName) + if node == nil { return fmt.Sprintf("<%s>", ruleName) } - + + // Determine the new recursion depth + // Only increment if we're moving within the same SCC (actual recursion) + newRecursionDepth := recursionDepth + if currentSCCID != grammar.NoSCC && node.SCCID == currentSCCID && node.IsRecursive { + // We're recursing within the same SCC + newRecursionDepth = recursionDepth + 1 + + // Check recursion depth limit + if newRecursionDepth >= g.config.MaxDepth { + return g.generateTerminalFallback(ruleName) + } + } else if node.IsRecursive { + // Entering a new recursive SCC, reset recursion depth to 0 + newRecursionDepth = 0 + } + + // Update current SCC context for recursive rules + newSCCID := currentSCCID + if node.IsRecursive { + newSCCID = node.SCCID + } + + if len(rule.Alternatives) == 0 { + return "" + } + + // Select a random alternative altIndex := g.random.Intn(len(rule.Alternatives)) alternative := rule.Alternatives[altIndex] // Generate from all elements in the alternative var result []string for _, element := range alternative.Elements { - elementResult := g.generateFromElement(&element, currentDepth) + elementResult := g.generateFromElementWithSCC(&element, newSCCID, newRecursionDepth) if elementResult != "" { result = append(result, elementResult) } @@ -109,37 +144,46 @@ func (g *Generator) generateFromRule(ruleName string, currentDepth int) string { // Format output based on configuration switch g.config.OutputFormat { case config.CompactOutput: - // Clean, readable output without verbose comments (default) return joinWithSpaces(result) case config.VerboseOutput: - // Full grammar rule traversal with comments return fmt.Sprintf("/* %s */ %s", ruleName, joinWithSpaces(result)) default: - // Default to compact return joinWithSpaces(result) } } -// generateFromElement generates text from a single grammar element -func (g *Generator) generateFromElement(element *grammar.Element, currentDepth int) string { +// generateTerminalFallback generates a simple fallback when recursion depth is exceeded +func (g *Generator) generateTerminalFallback(ruleName string) string { + // For recursive rules that hit depth limit, generate simple fallback + return generateSimpleFallback(ruleName) +} + +// SetGrammarForTesting sets the grammar for testing purposes +func (g *Generator) SetGrammarForTesting(grammar *grammar.ParsedGrammar) { + g.grammar = grammar + g.dependencyGraph = grammar.GetDependencyGraph() +} + +// generateFromElementWithSCC generates text from a single grammar element with SCC tracking +func (g *Generator) generateFromElementWithSCC(element *grammar.Element, currentSCCID int, recursionDepth int) string { // Handle optional elements if element.IsOptional() && g.random.Float64() > g.config.OptionalProb { - return "" // Skip optional element + return "" } // Handle quantified elements if element.IsQuantified() { - return g.generateQuantified(element, currentDepth) + return g.generateQuantifiedWithSCC(element, currentSCCID, recursionDepth) } // Generate single element if element.IsRule() { if refValue, ok := element.Value.(grammar.ReferenceValue); ok { - return g.generateFromRuleOrToken(refValue.Name, currentDepth+1) + return g.generateFromRuleOrTokenWithSCC(refValue.Name, currentSCCID, recursionDepth) } else if blockValue, ok := element.Value.(grammar.BlockValue); ok { - return g.generateFromBlock(blockValue, currentDepth) + return g.generateFromBlockWithSCC(blockValue, currentSCCID, recursionDepth) } - return g.generateFromRuleOrToken(element.Value.String(), currentDepth+1) + return g.generateFromRuleOrTokenWithSCC(element.Value.String(), currentSCCID, recursionDepth) } else if element.IsTerminal() { if litValue, ok := element.Value.(grammar.LiteralValue); ok { return cleanLiteral(litValue.Text) @@ -150,85 +194,8 @@ func (g *Generator) generateFromElement(element *grammar.Element, currentDepth i return element.Value.String() } -// generateQuantified handles quantified elements (* +) -func (g *Generator) generateQuantified(element *grammar.Element, currentDepth int) string { - var count int - - // Use fixed count if specified, otherwise use random count - if g.config.QuantifierCount > 0 { - count = g.config.QuantifierCount - } else { - switch element.Quantifier { - case grammar.ZERO_MORE: // * - count = g.random.Intn(g.config.MaxQuantifier + 1) // 0 to MaxQuantifier - case grammar.ONE_MORE: // + - count = 1 + g.random.Intn(g.config.MaxQuantifier) // 1 to MaxQuantifier - default: - count = 1 - } - } - - var results []string - for i := 0; i < count; i++ { - if element.IsRule() { - if refValue, ok := element.Value.(grammar.ReferenceValue); ok { - result := g.generateFromRuleOrToken(refValue.Name, currentDepth+1) - results = append(results, result) - } else if blockValue, ok := element.Value.(grammar.BlockValue); ok { - result := g.generateFromBlock(blockValue, currentDepth+1) - results = append(results, result) - } else { - result := g.generateFromRuleOrToken(element.Value.String(), currentDepth+1) - results = append(results, result) - } - } else if element.IsTerminal() { - if litValue, ok := element.Value.(grammar.LiteralValue); ok { - results = append(results, cleanLiteral(litValue.Text)) - } else { - results = append(results, cleanLiteral(element.Value.String())) - } - } - } - - return joinWithSpaces(results) -} - -// generateFromBlock generates content from a block value -func (g *Generator) generateFromBlock(blockValue grammar.BlockValue, currentDepth int) string { - if len(blockValue.Alternatives) == 0 { - return "" - } - - // Select a random alternative from the block - altIndex := g.random.Intn(len(blockValue.Alternatives)) - alternative := blockValue.Alternatives[altIndex] - - // Generate from all elements in the selected alternative - var result []string - for _, element := range alternative.Elements { - elementResult := g.generateFromElement(&element, currentDepth) - if elementResult != "" { - result = append(result, elementResult) - } - } - - return joinWithSpaces(result) -} - - -// generateFromRuleOrToken generates from a rule using standard rule-based generation -func (g *Generator) generateFromRuleOrToken(ruleName string, currentDepth int) string { - // Check if this is a lexer rule and generate concrete token - if rule := g.grammar.GetRule(ruleName); rule != nil && rule.IsLexer { - return g.generateConcreteToken(ruleName) - } - - // Otherwise expand as parser rule - return g.generateFromRule(ruleName, currentDepth) -} - // generateConcreteToken generates concrete tokens by expanding lexer rules -func (g *Generator) generateConcreteToken(ruleName string) string { +func (g *Generator) generateConcreteToken(ruleName string, depth int) string { // Get the lexer rule rule := g.grammar.GetRule(ruleName) if rule == nil || !rule.IsLexer { @@ -237,11 +204,17 @@ func (g *Generator) generateConcreteToken(ruleName string) string { // For lexer rules, we need to expand them but generate concrete characters // at the terminal level (character sets, literals, etc.) - return g.generateFromLexerRule(rule, 0) + return g.generateFromLexerRule(rule, depth) } // generateFromLexerRule generates content from a lexer rule func (g *Generator) generateFromLexerRule(rule *grammar.Rule, currentDepth int) string { + // Check recursion depth for lexer rules too + node := g.dependencyGraph.GetNode(rule.Name) + if node != nil && node.IsRecursive && currentDepth >= g.config.MaxDepth { + return generateSimpleFallback(rule.Name) + } + if len(rule.Alternatives) == 0 { return "" } @@ -253,7 +226,7 @@ func (g *Generator) generateFromLexerRule(rule *grammar.Rule, currentDepth int) // Generate from all elements in the alternative var result []string for _, element := range alternative.Elements { - elementResult := g.generateFromLexerElement(&element, currentDepth) + elementResult := g.generateFromLexerElement(&element, currentDepth+1) if elementResult != "" { result = append(result, elementResult) } @@ -279,10 +252,10 @@ func (g *Generator) generateFromLexerElement(element *grammar.Element, currentDe if refValue, ok := element.Value.(grammar.ReferenceValue); ok { // Check if referenced rule is lexer or parser if referencedRule := g.grammar.GetRule(refValue.Name); referencedRule != nil && referencedRule.IsLexer { - return g.generateFromLexerRule(referencedRule, currentDepth+1) + return g.generateFromLexerRule(referencedRule, currentDepth) } else { // Parser rule - shouldn't happen in lexer context, but handle it - return g.generateFromRule(refValue.Name, currentDepth+1) + return g.generateFromRule(refValue.Name, currentDepth) } } else if blockValue, ok := element.Value.(grammar.BlockValue); ok { return g.generateFromLexerBlock(blockValue, currentDepth) @@ -301,7 +274,7 @@ func (g *Generator) generateFromLexerElement(element *grammar.Element, currentDe // generateQuantifiedLexer handles quantified lexer elements func (g *Generator) generateQuantifiedLexer(element *grammar.Element, currentDepth int) string { var count int - + // Use fixed count if specified, otherwise use random count if g.config.QuantifierCount > 0 { count = g.config.QuantifierCount @@ -321,7 +294,7 @@ func (g *Generator) generateQuantifiedLexer(element *grammar.Element, currentDep result := g.generateFromLexerElement(&grammar.Element{ Value: element.Value, Quantifier: grammar.NONE, // Remove quantifier for individual generation - }, currentDepth+1) + }, currentDepth) if result != "" { results = append(results, result) } @@ -356,16 +329,16 @@ func (g *Generator) generateFromLexerBlock(blockValue grammar.BlockValue, curren func (g *Generator) generateFromLiteral(literal string) string { // Handle character sets like ~[\u0000"] or [a-zA-Z_] if strings.HasPrefix(literal, "~[") && strings.HasSuffix(literal, "]") { - return g.generateFromNegatedSet(literal[2 : len(literal)-1]) + return g.generateFromNegatedSet() } else if strings.HasPrefix(literal, "[") && strings.HasSuffix(literal, "]") { return g.generateFromCharacterSet(literal[1 : len(literal)-1]) } - + // Handle string literals if strings.HasPrefix(literal, "'") && strings.HasSuffix(literal, "'") && len(literal) >= 2 { return literal[1 : len(literal)-1] // Remove quotes } - + // Handle special escape sequences switch literal { case "\\r": @@ -381,7 +354,7 @@ func (g *Generator) generateFromLiteral(literal string) string { case "\\\\": return "\\" } - + // Return as-is for other cases return literal } @@ -389,7 +362,7 @@ func (g *Generator) generateFromLiteral(literal string) string { // generateFromCharacterSet generates a random character from a character set like [a-zA-Z_] func (g *Generator) generateFromCharacterSet(charset string) string { chars := []rune{} - + // Simple character set expansion - handle ranges like a-z, A-Z, 0-9 i := 0 for i < len(charset) { @@ -407,25 +380,24 @@ func (g *Generator) generateFromCharacterSet(charset string) string { i++ } } - + if len(chars) == 0 { return "x" // Fallback } - + return string(chars[g.random.Intn(len(chars))]) } // generateFromNegatedSet generates a character NOT in the specified set -func (g *Generator) generateFromNegatedSet(negatedSet string) string { +func (g *Generator) generateFromNegatedSet() string { // For simplicity, generate common safe characters that are typically not in negated sets safeChars := []string{"a", "b", "c", "x", "y", "z", "_", "1", "2", "3"} - + // TODO: Implement proper negated set handling by expanding the set and excluding those characters // For now, just return a safe character return safeChars[g.random.Intn(len(safeChars))] } - // cleanLiteral removes quotes from literal strings func cleanLiteral(literal string) string { // Remove single quotes from literals like 'SELECT' @@ -457,10 +429,100 @@ func joinStrings(strs []string, sep string) string { if len(strs) == 1 { return strs[0] } - + result := strs[0] for i := 1; i < len(strs); i++ { result += sep + strs[i] } return result -} \ No newline at end of file +} + +// generateQuantifiedWithSCC handles quantified elements with SCC tracking +func (g *Generator) generateQuantifiedWithSCC(element *grammar.Element, currentSCCID int, recursionDepth int) string { + var count int + + if g.config.QuantifierCount > 0 { + count = g.config.QuantifierCount + } else { + switch element.Quantifier { + case grammar.ZERO_MORE: // * + count = g.random.Intn(g.config.MaxQuantifier + 1) + case grammar.ONE_MORE: // + + count = 1 + g.random.Intn(g.config.MaxQuantifier) + default: + count = 1 + } + } + + var results []string + for i := 0; i < count; i++ { + if element.IsRule() { + if refValue, ok := element.Value.(grammar.ReferenceValue); ok { + result := g.generateFromRuleOrTokenWithSCC(refValue.Name, currentSCCID, recursionDepth) + results = append(results, result) + } else if blockValue, ok := element.Value.(grammar.BlockValue); ok { + result := g.generateFromBlockWithSCC(blockValue, currentSCCID, recursionDepth) + results = append(results, result) + } else { + result := g.generateFromRuleOrTokenWithSCC(element.Value.String(), currentSCCID, recursionDepth) + results = append(results, result) + } + } else if element.IsTerminal() { + if litValue, ok := element.Value.(grammar.LiteralValue); ok { + results = append(results, cleanLiteral(litValue.Text)) + } else { + results = append(results, cleanLiteral(element.Value.String())) + } + } + } + + return joinWithSpaces(results) +} + +// generateFromBlockWithSCC generates content from a block value with SCC tracking +func (g *Generator) generateFromBlockWithSCC(blockValue grammar.BlockValue, currentSCCID int, recursionDepth int) string { + if len(blockValue.Alternatives) == 0 { + return "" + } + + altIndex := g.random.Intn(len(blockValue.Alternatives)) + alternative := blockValue.Alternatives[altIndex] + + var result []string + for _, element := range alternative.Elements { + elementResult := g.generateFromElementWithSCC(&element, currentSCCID, recursionDepth) + if elementResult != "" { + result = append(result, elementResult) + } + } + + return joinWithSpaces(result) +} + +// generateFromRuleOrTokenWithSCC generates from a rule or token with SCC tracking +func (g *Generator) generateFromRuleOrTokenWithSCC(ruleName string, currentSCCID int, recursionDepth int) string { + if rule := g.grammar.GetRule(ruleName); rule != nil && rule.IsLexer { + // Lexer rules don't participate in SCC recursion tracking + return g.generateConcreteToken(ruleName, 0) + } + return g.generateFromRuleWithSCC(ruleName, currentSCCID, recursionDepth) +} + +// generateSimpleFallback generates a simple fallback value based on rule name patterns +func generateSimpleFallback(ruleName string) string { + // Generate context-appropriate fallbacks + ruleLower := strings.ToLower(ruleName) + + if strings.Contains(ruleLower, "string") || strings.Contains(ruleLower, "constant") { + return "'fallback'" + } else if strings.Contains(ruleLower, "expr") || strings.Contains(ruleLower, "expression") { + return "1" + } else if strings.Contains(ruleLower, "name") || strings.Contains(ruleLower, "id") { + return "col1" + } else if strings.Contains(ruleLower, "number") || strings.Contains(ruleLower, "numeric") { + return "1" + } else { + // Generic fallback + return "1" + } +} diff --git a/tools/fuzzing/internal/grammar/dependency.go b/tools/fuzzing/internal/grammar/dependency.go new file mode 100644 index 0000000..0160e61 --- /dev/null +++ b/tools/fuzzing/internal/grammar/dependency.go @@ -0,0 +1,463 @@ +package grammar + +import ( + "fmt" +) + +const ( + // NoSCC indicates a node is not part of any SCC or SCC not yet computed + NoSCC = -1 +) + +// DependencyGraph represents the dependency relationships between grammar rules +type DependencyGraph struct { + Nodes map[string]*GraphNode + Edges map[string][]string // Adjacency list: rule -> referenced rules + SCCs [][]string // List of SCCs (each SCC is a list of rule names) + SCCLookup map[string]int // Rule name -> SCC ID lookup map +} + +// GraphNode represents a single rule in the dependency graph +type GraphNode struct { + RuleName string // Rule name (e.g., "selectStmt", "expr") + Alternatives []Alternative // All alternatives for this rule + IsLexer bool // Whether this is a lexer rule + SCCID int // Which SCC this node belongs to (NoSCC if not computed) + SCCSize int // Size of the SCC this node belongs to + IsRecursive bool // True if part of a recursive SCC (size > 1 or self-loop) +} + +// NewDependencyGraph creates a new dependency graph +func NewDependencyGraph() *DependencyGraph { + return &DependencyGraph{ + Nodes: make(map[string]*GraphNode), + Edges: make(map[string][]string), + SCCs: [][]string{}, + SCCLookup: make(map[string]int), + } +} + +// AddNode adds a rule node to the dependency graph +func (g *DependencyGraph) AddNode(ruleName string, rule *Rule) { + node := &GraphNode{ + RuleName: ruleName, + Alternatives: rule.Alternatives, + IsLexer: rule.IsLexer, + SCCID: NoSCC, + SCCSize: 0, + IsRecursive: false, + } + g.Nodes[ruleName] = node + + // Don't build edges here because this rule may reference other rules that + // haven't been added yet (forward references). Edges will be built later + // after all nodes are added via BuildEdges() +} + +// GetNode retrieves a node by rule name +func (g *DependencyGraph) GetNode(ruleName string) *GraphNode { + return g.Nodes[ruleName] +} + +// ValidateGrammar checks if all non-recursive rules can reach terminal symbols +func (g *DependencyGraph) ValidateGrammar() error { + // For now, we trust that the grammar is well-formed + // Future: could add validation to ensure non-recursive rules can terminate + return nil +} + +// PrintAnalysisResults prints the dependency graph analysis results for debugging +func (g *DependencyGraph) PrintAnalysisResults() { + fmt.Println("=== Dependency Graph Analysis Results ===") + for ruleName, node := range g.Nodes { + fmt.Printf("Rule: %s (lexer=%t)\n", ruleName, node.IsLexer) + fmt.Printf(" IsRecursive: %t\n", node.IsRecursive) + fmt.Printf(" SCCID: %d, SCCSize: %d\n", node.SCCID, node.SCCSize) + fmt.Printf(" Total alternatives: %d\n", len(node.Alternatives)) + fmt.Println() + } +} + +// collectRuleReferences collects all rule references in an alternative +func (g *DependencyGraph) collectRuleReferences(alt Alternative, refs map[string]bool) { + for _, element := range alt.Elements { + g.collectElementReferences(element, refs) + } +} + +// collectElementReferences collects rule references from a single element +func (g *DependencyGraph) collectElementReferences(element Element, refs map[string]bool) { + if element.IsRule() { + switch value := element.Value.(type) { + case ReferenceValue: + refs[value.Name] = true + case BlockValue: + for _, alt := range value.Alternatives { + g.collectRuleReferences(alt, refs) + } + } + } +} + +// BuildEdges builds all edges after all nodes have been added +func (g *DependencyGraph) BuildEdges() { + g.Edges = make(map[string][]string) + + for ruleName, node := range g.Nodes { + referencedRules := make(map[string]bool) + + for _, alt := range node.Alternatives { + g.collectRuleReferences(alt, referencedRules) + } + + // Only add edges to parser rules (exclude lexer rules) + // But include all referenced parser rules, even if they don't exist yet + edges := []string{} + for ref := range referencedRules { + // Check if the referenced rule is a lexer rule + if refNode := g.GetNode(ref); refNode != nil && refNode.IsLexer { + continue // Skip lexer rules + } + // Add all other references (including forward references) + edges = append(edges, ref) + } + g.Edges[ruleName] = edges + } +} + +// ComputeSCCs computes strongly connected components using Tarjan's algorithm +func (g *DependencyGraph) ComputeSCCs() { + if len(g.Edges) == 0 { + g.BuildEdges() + } + + index := 0 + stack := []string{} + indices := make(map[string]int) + lowlinks := make(map[string]int) + onStack := make(map[string]bool) + + // Helper function for Tarjan's strongconnect + var strongconnect func(v string) + strongconnect = func(v string) { + // Set the depth index for v to the smallest unused index + indices[v] = index + lowlinks[v] = index + index++ + stack = append(stack, v) + onStack[v] = true + + // Consider successors of v + for _, w := range g.Edges[v] { + if _, ok := indices[w]; !ok { + // Successor w has not yet been visited; recurse on it + strongconnect(w) + if lowlinks[w] < lowlinks[v] { + lowlinks[v] = lowlinks[w] + } + } else if onStack[w] { + // Successor w is in stack S and hence in the current SCC + if indices[w] < lowlinks[v] { + lowlinks[v] = indices[w] + } + } + } + + // If v is a root node, pop the stack and print an SCC + if lowlinks[v] == indices[v] { + scc := []string{} + for { + w := stack[len(stack)-1] + stack = stack[:len(stack)-1] + onStack[w] = false + scc = append(scc, w) + if w == v { + break + } + } + g.SCCs = append(g.SCCs, scc) + } + } + + // Clear existing SCCs and lookup map + g.SCCs = [][]string{} + g.SCCLookup = make(map[string]int) + + // Run algorithm for all unvisited nodes + for ruleName := range g.Nodes { + if _, ok := indices[ruleName]; !ok { + strongconnect(ruleName) + } + } + + // Perform sanity check: ensure no SCC is an isolated island + // Only log warnings, don't fail - test cases often have isolated SCCs + if err := g.checkForIsolatedSCCs(); err != nil { + // Log warning but continue - tests may have intentionally isolated SCCs + fmt.Printf("Warning: %v\n", err) + } + + // Build SCC lookup map and update nodes with their SCC information + for sccID, scc := range g.SCCs { + sccSize := len(scc) + isRecursive := sccSize > 1 + + // Check for self-loops if single node SCC + if sccSize == 1 { + ruleName := scc[0] + for _, ref := range g.Edges[ruleName] { + if ref == ruleName { + isRecursive = true + break + } + } + } + + // Update lookup map and nodes in this SCC + for _, ruleName := range scc { + // Add to lookup map + g.SCCLookup[ruleName] = sccID + + // Update node information + if node := g.GetNode(ruleName); node != nil { + node.SCCID = sccID + node.SCCSize = sccSize + node.IsRecursive = isRecursive + } + } + } +} + +// checkForIsolatedSCCs ensures no SCC is an isolated island with no exit paths +func (g *DependencyGraph) checkForIsolatedSCCs() error { + // Create a temporary SCC membership map for this check + sccMembership := make(map[string]int) + for sccID, scc := range g.SCCs { + for _, ruleName := range scc { + sccMembership[ruleName] = sccID + } + } + + // Check each SCC for exit paths + isolatedSCCs := []int{} + for sccID, scc := range g.SCCs { + // Skip non-recursive SCCs (single nodes without self-loops) + if len(scc) == 1 { + ruleName := scc[0] + hasSelfLoop := false + for _, ref := range g.Edges[ruleName] { + if ref == ruleName { + hasSelfLoop = true + break + } + } + if !hasSelfLoop { + continue // Non-recursive single node, skip + } + } + + // Check if this SCC has any exit path + hasExit := g.sccHasExitPath(sccID, scc, sccMembership) + if !hasExit { + isolatedSCCs = append(isolatedSCCs, sccID) + } + } + + // Report error if any isolated SCCs found + if len(isolatedSCCs) > 0 { + fmt.Printf("\nERROR: Found %d isolated SCC(s) with no exit paths:\n", len(isolatedSCCs)) + for _, sccID := range isolatedSCCs { + fmt.Printf(" SCC %d: %v\n", sccID, g.SCCs[sccID]) + } + return fmt.Errorf("grammar contains %d isolated SCC(s) that cannot terminate", len(isolatedSCCs)) + } + + return nil +} + +// sccHasExitPath checks if an SCC has at least one path to rules outside of it +func (g *DependencyGraph) sccHasExitPath(sccID int, scc []string, sccMembership map[string]int) bool { + // Use fixed-point iteration to find reachable rules from this SCC + visited := make(map[string]bool) + toVisit := []string{} + + // Start with all rules in the SCC + for _, ruleName := range scc { + toVisit = append(toVisit, ruleName) + visited[ruleName] = true + } + + // Perform reachability analysis + for len(toVisit) > 0 { + current := toVisit[0] + toVisit = toVisit[1:] + + // Check all references from current rule + for _, ref := range g.Edges[current] { + // Skip if already visited + if visited[ref] { + continue + } + + // Check if referenced rule is outside this SCC + refSCCID, exists := sccMembership[ref] + if !exists || refSCCID != sccID { + // Found an exit! Check if it can eventually reach terminals + if g.canReachTerminal(ref, make(map[string]bool)) { + return true + } + } + + // Mark as visited and continue searching + visited[ref] = true + toVisit = append(toVisit, ref) + } + + // Also check alternatives for direct terminal paths + if node := g.GetNode(current); node != nil { + for _, alt := range node.Alternatives { + if g.alternativeHasTerminalPath(alt) { + return true + } + } + } + } + + return false +} + +// canReachTerminal checks if a rule can eventually reach terminal symbols +func (g *DependencyGraph) canReachTerminal(ruleName string, visited map[string]bool) bool { + // Avoid infinite recursion + if visited[ruleName] { + return false + } + visited[ruleName] = true + + node := g.GetNode(ruleName) + if node == nil { + return false + } + + // Lexer rules are terminals + if node.IsLexer { + return true + } + + // Check each alternative + for _, alt := range node.Alternatives { + if g.alternativeCanReachTerminal(alt, visited) { + return true + } + } + + return false +} + +// alternativeHasTerminalPath checks if an alternative has at least one terminal +func (g *DependencyGraph) alternativeHasTerminalPath(alt Alternative) bool { + for _, element := range alt.Elements { + if element.IsTerminal() { + return true + } + // Check if it's an optional/quantified element (can be skipped) + if element.Quantifier == ZERO_MORE || element.Quantifier == OPTIONAL_Q { + return true + } + } + return false +} + +// alternativeCanReachTerminal checks if an alternative can reach terminals +func (g *DependencyGraph) alternativeCanReachTerminal(alt Alternative, visited map[string]bool) bool { + if len(alt.Elements) == 0 { + return true // Empty alternative is terminal + } + + for _, element := range alt.Elements { + if element.IsTerminal() { + return true + } + + // Optional elements can be skipped + if element.Quantifier == ZERO_MORE || element.Quantifier == OPTIONAL_Q { + continue + } + + // Check if referenced rule can reach terminal + if refValue, ok := element.Value.(ReferenceValue); ok { + if !g.canReachTerminal(refValue.Name, visited) { + return false + } + } + } + + return true +} + +// PrintSCCAnalysis prints the SCC analysis results for debugging +func (g *DependencyGraph) PrintSCCAnalysis() { + fmt.Println("\n=== SCC Analysis Results ===") + fmt.Printf("Total SCCs: %d\n", len(g.SCCs)) + + recursiveSCCs := 0 + selfLoopSCCs := 0 + largestSCC := 0 + for i, scc := range g.SCCs { + if len(scc) > 1 { + recursiveSCCs++ + if len(scc) > largestSCC { + largestSCC = len(scc) + } + // Print first 5 multi-node SCCs with more detail + if recursiveSCCs <= 5 { + fmt.Printf("\nSCC %d (RECURSIVE - mutual, size=%d):\n", i, len(scc)) + // Print first 20 nodes of the SCC for better visibility + fmt.Printf(" Members: ") + for j, node := range scc { + if j < 20 { + fmt.Printf("%s ", node) + if j == 19 && len(scc) > 20 { + fmt.Printf("\n ... and %d more", len(scc)-20) + } + } + } + fmt.Println() + } + } else if len(scc) == 1 { + // Check for self-loop + ruleName := scc[0] + hasSelfLoop := false + for _, ref := range g.Edges[ruleName] { + if ref == ruleName { + hasSelfLoop = true + break + } + } + if hasSelfLoop { + selfLoopSCCs++ + if selfLoopSCCs <= 10 { // Print first 10 + fmt.Printf("SCC %d (RECURSIVE - self-loop): %s\n", i, ruleName) + } + } + } + } + + fmt.Printf("\nMutually recursive SCCs (size > 1): %d\n", recursiveSCCs) + if recursiveSCCs > 0 { + fmt.Printf("Largest SCC size: %d\n", largestSCC) + } + fmt.Printf("Self-loop SCCs (size = 1 with self-ref): %d\n", selfLoopSCCs) + fmt.Printf("Non-recursive SCCs: %d\n", len(g.SCCs)-recursiveSCCs-selfLoopSCCs) + + // Print sample of recursive rules + fmt.Println("\nSample recursive rules:") + count := 0 + for ruleName, node := range g.Nodes { + if node.IsRecursive && count < 10 { + fmt.Printf(" %s (SCC %d, size %d)\n", ruleName, node.SCCID, node.SCCSize) + count++ + } + } + fmt.Println("=============================") +} diff --git a/tools/fuzzing/internal/grammar/parser.go b/tools/fuzzing/internal/grammar/parser.go index cd43d1c..1ec1c8f 100644 --- a/tools/fuzzing/internal/grammar/parser.go +++ b/tools/fuzzing/internal/grammar/parser.go @@ -6,8 +6,8 @@ import ( "strings" "github.com/antlr4-go/antlr/v4" - "github.com/pkg/errors" grammar "github.com/bytebase/parser/tools/grammar" + "github.com/pkg/errors" ) // ParsedGrammar represents a parsed grammar with extracted rules @@ -18,6 +18,8 @@ type ParsedGrammar struct { // BlockAltMap stores temporary block rules for debugging // Key: block ID (e.g., "block_1_alts"), Value: the block alternatives BlockAltMap map[string][]Alternative + // DependencyGraph for recursion analysis + DependencyGraph *DependencyGraph } // Rule represents a grammar rule with its alternatives @@ -57,7 +59,7 @@ func (r ReferenceValue) String() string { return r.Name } // BlockValue represents a generated block (e.g., (',' column)*) type BlockValue struct { - ID string // Global unique ID like "block_1_alts" + ID string // Global unique ID like "block_1_alts" Alternatives []Alternative } @@ -75,7 +77,6 @@ func (b BlockValue) String() string { return b.ID } - // WildcardValue represents a wildcard (.) type WildcardValue struct{} @@ -91,15 +92,14 @@ type Element struct { type Quantifier int const ( - NONE Quantifier = iota - OPTIONAL_Q // ? - ZERO_MORE // * - ONE_MORE // + + NONE Quantifier = iota + OPTIONAL_Q // ? + ZERO_MORE // * + ONE_MORE // + ) -// ParseGrammarFile parses a .g4 file and extracts rules for fuzzing -func ParseGrammarFile(filePath string) (*ParsedGrammar, error) { - // Read file content +// parseGrammarFileWithoutDependencyGraph parses a .g4 file without building dependency graph +func parseGrammarFileWithoutDependencyGraph(filePath string) (*ParsedGrammar, error) { content, err := os.ReadFile(filePath) if err != nil { return nil, errors.Wrap(err, "failed to read grammar file") @@ -109,31 +109,20 @@ func ParseGrammarFile(filePath string) (*ParsedGrammar, error) { return nil, errors.New("grammar file is empty") } - // Create input stream input := antlr.NewInputStream(string(content)) - - // Create lexer lexer := grammar.NewANTLRv4Lexer(input) - // Add error listener errorListener := &GrammarErrorListener{} lexer.RemoveErrorListeners() lexer.AddErrorListener(errorListener) - // Create token stream stream := antlr.NewCommonTokenStream(lexer, 0) - - // Create parser parser := grammar.NewANTLRv4Parser(stream) - - // Add error listener to parser parser.RemoveErrorListeners() parser.AddErrorListener(errorListener) - // Parse the grammar tree := parser.GrammarSpec() - // Check for parsing errors if errorListener.HasErrors() { return nil, errors.Errorf("failed to parse grammar: %v", errorListener.GetErrors()) } @@ -142,18 +131,57 @@ func ParseGrammarFile(filePath string) (*ParsedGrammar, error) { return nil, errors.New("parser returned nil tree") } - // Extract rules from parse tree visitor := NewGrammarExtractorVisitor() visitor.VisitGrammarSpec(tree) + parsedGrammar := &ParsedGrammar{ + LexerRules: visitor.lexerRules, + ParserRules: visitor.parserRules, + FilePath: filePath, + BlockAltMap: visitor.blockAltMap, + DependencyGraph: nil, + } + + return parsedGrammar, nil +} + +// ParseGrammarFile parses a .g4 file and extracts rules for fuzzing (legacy method) +func ParseGrammarFile(filePath string) (*ParsedGrammar, error) { + parsedGrammar, err := parseGrammarFileWithoutDependencyGraph(filePath) + if err != nil { + return nil, err + } + parsedGrammar.DependencyGraph = NewDependencyGraph() + if err := buildDependencyGraph(parsedGrammar); err != nil { + return nil, fmt.Errorf("failed to build dependency graph: %w", err) + } - return &ParsedGrammar{ - LexerRules: visitor.lexerRules, - ParserRules: visitor.parserRules, - FilePath: filePath, - BlockAltMap: visitor.blockAltMap, - }, nil + return parsedGrammar, nil +} + +// buildDependencyGraph constructs the dependency graph for the parsed grammar +func buildDependencyGraph(parsedGrammar *ParsedGrammar) error { + return buildDependencyGraphWithValidation(parsedGrammar) +} + +// buildDependencyGraphWithValidation constructs the dependency graph with optional validation +func buildDependencyGraphWithValidation(parsedGrammar *ParsedGrammar) error { + // Add all lexer rules to the graph + for ruleName, rule := range parsedGrammar.LexerRules { + parsedGrammar.DependencyGraph.AddNode(ruleName, rule) + } + + // Add all parser rules to the graph + for ruleName, rule := range parsedGrammar.ParserRules { + parsedGrammar.DependencyGraph.AddNode(ruleName, rule) + } + + // Perform SCC computing. + parsedGrammar.DependencyGraph.ComputeSCCs() + parsedGrammar.DependencyGraph.PrintSCCAnalysis() + + return nil } // GetRule gets a rule by name from either lexer or parser rules @@ -193,35 +221,45 @@ func (g *ParsedGrammar) IsGeneratedBlock(name string) bool { // MergeGrammar merges another grammar into this one func (g *ParsedGrammar) MergeGrammar(other *ParsedGrammar) error { - // Merge lexer rules for name, rule := range other.LexerRules { if _, exists := g.LexerRules[name]; exists { return fmt.Errorf("duplicate lexer rule '%s' found in grammars '%s' and '%s'", name, g.FilePath, other.FilePath) } g.LexerRules[name] = rule } - - // Merge parser rules + for name, rule := range other.ParserRules { if _, exists := g.ParserRules[name]; exists { return fmt.Errorf("duplicate parser rule '%s' found in grammars '%s' and '%s'", name, g.FilePath, other.FilePath) } g.ParserRules[name] = rule } - - // Merge block alternatives map + for blockID, alternatives := range other.BlockAltMap { if _, exists := g.BlockAltMap[blockID]; exists { return fmt.Errorf("duplicate block ID '%s' found in grammars '%s' and '%s'", blockID, g.FilePath, other.FilePath) } g.BlockAltMap[blockID] = alternatives } - - // Update file path to indicate it's a merged grammar + if g.FilePath != other.FilePath { g.FilePath = fmt.Sprintf("%s + %s", g.FilePath, other.FilePath) } - + + return nil +} + +// MergeGrammarAndRebuildGraph merges another grammar and rebuilds the dependency graph (for single file merging) +func (g *ParsedGrammar) MergeGrammarAndRebuildGraph(other *ParsedGrammar) error { + if err := g.MergeGrammar(other); err != nil { + return err + } + + g.DependencyGraph = NewDependencyGraph() + if err := buildDependencyGraphWithValidation(g); err != nil { + return fmt.Errorf("failed to rebuild dependency graph after merge: %w", err) + } + return nil } @@ -230,29 +268,51 @@ func ParseAndMergeGrammarFiles(filePaths []string) (*ParsedGrammar, error) { if len(filePaths) == 0 { return nil, errors.New("no grammar files provided") } - - // Parse the first grammar file - mergedGrammar, err := ParseGrammarFile(filePaths[0]) - if err != nil { - return nil, errors.Wrapf(err, "failed to parse first grammar file %s", filePaths[0]) - } - - // Merge additional grammar files - for i := 1; i < len(filePaths); i++ { - filePath := filePaths[i] - grammar, err := ParseGrammarFile(filePath) + + grammars := make([]*ParsedGrammar, 0, len(filePaths)) + for _, filePath := range filePaths { + grammar, err := parseGrammarFileWithoutDependencyGraph(filePath) if err != nil { return nil, errors.Wrapf(err, "failed to parse grammar file %s", filePath) } - - if err := mergedGrammar.MergeGrammar(grammar); err != nil { - return nil, errors.Wrapf(err, "failed to merge grammar file %s", filePath) + grammars = append(grammars, grammar) + } + + mergedGrammar := grammars[0] + for i := 1; i < len(grammars); i++ { + if err := mergedGrammar.MergeGrammar(grammars[i]); err != nil { + return nil, errors.Wrapf(err, "failed to merge grammar file %s", grammars[i].FilePath) } } - + + mergedGrammar.DependencyGraph = NewDependencyGraph() + if err := buildDependencyGraphWithValidation(mergedGrammar); err != nil { + return nil, fmt.Errorf("failed to build dependency graph after merging all files: %w", err) + } + return mergedGrammar, nil } +// GetDependencyGraph returns the dependency graph for the parsed grammar +func (g *ParsedGrammar) GetDependencyGraph() *DependencyGraph { + return g.DependencyGraph +} + +// ValidateGrammar validates that the grammar has valid dependency structure +func (g *ParsedGrammar) ValidateGrammar() error { + if g.DependencyGraph == nil { + return fmt.Errorf("dependency graph not built") + } + return g.DependencyGraph.ValidateGrammar() +} + +// PrintDependencyAnalysis prints dependency graph analysis for debugging +func (g *ParsedGrammar) PrintDependencyAnalysis() { + if g.DependencyGraph != nil { + g.DependencyGraph.PrintAnalysisResults() + } +} + // IsRule checks if an element refers to another rule or generated block func (e *Element) IsRule() bool { _, isRef := e.Value.(ReferenceValue) @@ -274,7 +334,7 @@ func (e *Element) IsOptional() bool { // IsQuantified checks if an element has repetition quantifiers func (e *Element) IsQuantified() bool { - return e.Quantifier == ZERO_MORE || e.Quantifier == ONE_MORE + return e.Quantifier == ZERO_MORE || e.Quantifier == ONE_MORE || e.Quantifier == OPTIONAL_Q } // GrammarErrorListener collects parsing errors @@ -568,7 +628,7 @@ func (v *GrammarExtractorVisitor) extractLexerAtom(lexerAtomCtx grammar.ILexerAt // Handle not set (e.g., ~[abc]) if notSetCtx := lexerAtomCtx.NotSet(); notSetCtx != nil { - return v.extractNotSet(notSetCtx) + return v.extractNotSet() } // Handle lexer character set (e.g., [abc]) @@ -597,7 +657,7 @@ func (v *GrammarExtractorVisitor) extractLexerBlock(lexerBlockCtx grammar.ILexer blockID := fmt.Sprintf("lexer_block_%d_alts", globalBlockID) emptyAlts := []Alternative{} v.blockAltMap[blockID] = emptyAlts - + return &Element{ Value: BlockValue{ID: blockID, Alternatives: emptyAlts}, } @@ -610,7 +670,7 @@ func (v *GrammarExtractorVisitor) extractLexerBlock(lexerBlockCtx grammar.ILexer blockID := fmt.Sprintf("lexer_block_%d_alts", globalBlockID) emptyAlts := []Alternative{} v.blockAltMap[blockID] = emptyAlts - + return &Element{ Value: BlockValue{ID: blockID, Alternatives: emptyAlts}, } @@ -630,12 +690,12 @@ func (v *GrammarExtractorVisitor) extractLexerBlock(lexerBlockCtx grammar.ILexer } blockAlternatives = append(blockAlternatives, Alternative{Elements: elements}) } - + // Generate global unique block ID and store mapping globalBlockID++ blockID := fmt.Sprintf("lexer_block_%d_alts", globalBlockID) v.blockAltMap[blockID] = blockAlternatives - + return &Element{ Value: BlockValue{ID: blockID, Alternatives: blockAlternatives}, } @@ -657,7 +717,7 @@ func (v *GrammarExtractorVisitor) extractCharacterRange(characterRangeCtx gramma } // extractNotSet extracts a not set (e.g., ~[abc]) -func (v *GrammarExtractorVisitor) extractNotSet(notSetCtx grammar.INotSetContext) *Element { +func (v *GrammarExtractorVisitor) extractNotSet() *Element { // For now, represent as a literal text // In a real implementation, this would need more sophisticated handling return &Element{ @@ -715,7 +775,6 @@ func (v *GrammarExtractorVisitor) extractTerminalDef(terminalDefCtx grammar.ITer return nil } - // extractRuleRef extracts a rule reference func (v *GrammarExtractorVisitor) extractRuleRef(rulerefCtx grammar.IRulerefContext) *Element { if ruleRefToken := rulerefCtx.RULE_REF(); ruleRefToken != nil { @@ -735,7 +794,7 @@ func (v *GrammarExtractorVisitor) extractBlock(blockCtx grammar.IBlockContext) * blockID := fmt.Sprintf("block_%d_alts", globalBlockID) emptyAlts := []Alternative{} v.blockAltMap[blockID] = emptyAlts - + return &Element{ Value: BlockValue{ID: blockID, Alternatives: emptyAlts}, } @@ -748,7 +807,7 @@ func (v *GrammarExtractorVisitor) extractBlock(blockCtx grammar.IBlockContext) * blockID := fmt.Sprintf("block_%d_alts", globalBlockID) emptyAlts := []Alternative{} v.blockAltMap[blockID] = emptyAlts - + return &Element{ Value: BlockValue{ID: blockID, Alternatives: emptyAlts}, } @@ -771,12 +830,12 @@ func (v *GrammarExtractorVisitor) extractBlock(blockCtx grammar.IBlockContext) * if len(blockAlternatives) == 1 && len(blockAlternatives[0].Elements) == 1 { return &blockAlternatives[0].Elements[0] } - + // Generate global unique block ID and store mapping globalBlockID++ blockID := fmt.Sprintf("block_%d_alts", globalBlockID) v.blockAltMap[blockID] = blockAlternatives - + return &Element{ Value: BlockValue{ID: blockID, Alternatives: blockAlternatives}, } @@ -825,4 +884,4 @@ func (v *GrammarExtractorVisitor) extractQuantifier(ebnfSuffixCtx grammar.IEbnfS } return NONE -} \ No newline at end of file +} diff --git a/tools/fuzzing/internal/grammar/scc_test.go b/tools/fuzzing/internal/grammar/scc_test.go new file mode 100644 index 0000000..27c0672 --- /dev/null +++ b/tools/fuzzing/internal/grammar/scc_test.go @@ -0,0 +1,333 @@ +package grammar + +import ( + "testing" +) + +// TestSCCDetection tests the SCC detection algorithm with various graph patterns +func TestSCCDetection(t *testing.T) { + tests := []struct { + name string + rules map[string][]string // rule -> references + expectedSCCs [][]string // expected SCCs + recursiveRules map[string]bool // which rules should be marked recursive + }{ + { + name: "Simple self-loop", + rules: map[string][]string{ + "a": {"a"}, + }, + expectedSCCs: [][]string{ + {"a"}, + }, + recursiveRules: map[string]bool{ + "a": true, + }, + }, + { + name: "Mutual recursion (2 nodes)", + rules: map[string][]string{ + "a": {"b"}, + "b": {"a"}, + }, + expectedSCCs: [][]string{ + {"b", "a"}, // Order might vary due to algorithm + }, + recursiveRules: map[string]bool{ + "a": true, + "b": true, + }, + }, + { + name: "Cycle of 3 nodes", + rules: map[string][]string{ + "a": {"b"}, + "b": {"c"}, + "c": {"a"}, + }, + expectedSCCs: [][]string{ + {"c", "b", "a"}, + }, + recursiveRules: map[string]bool{ + "a": true, + "b": true, + "c": true, + }, + }, + { + name: "Non-recursive with reference", + rules: map[string][]string{ + "a": {"b"}, + "b": {"c"}, + "c": {}, + }, + expectedSCCs: [][]string{ + {"c"}, + {"b"}, + {"a"}, + }, + recursiveRules: map[string]bool{ + "a": false, + "b": false, + "c": false, + }, + }, + { + name: "Multiple SCCs", + rules: map[string][]string{ + "a": {"b"}, + "b": {"a"}, + "c": {"d"}, + "d": {"c"}, + "e": {}, + }, + expectedSCCs: [][]string{ + {"b", "a"}, + {"d", "c"}, + {"e"}, + }, + recursiveRules: map[string]bool{ + "a": true, + "b": true, + "c": true, + "d": true, + "e": false, + }, + }, + { + name: "Complex with bridge", + rules: map[string][]string{ + "a": {"b", "c"}, + "b": {"a"}, + "c": {"d"}, + "d": {"e"}, + "e": {"c"}, + }, + expectedSCCs: [][]string{ + {"b", "a"}, + {"e", "d", "c"}, + }, + recursiveRules: map[string]bool{ + "a": true, + "b": true, + "c": true, + "d": true, + "e": true, + }, + }, + { + name: "Self-loop with external reference", + rules: map[string][]string{ + "expr": {"expr", "literal"}, + "literal": {}, + }, + expectedSCCs: [][]string{ + {"expr"}, + {"literal"}, + }, + recursiveRules: map[string]bool{ + "expr": true, + "literal": false, + }, + }, + { + name: "PostgreSQL-like pattern", + rules: map[string][]string{ + "select_with_parens": {"select_no_parens", "select_with_parens"}, + "select_no_parens": {"table_ref"}, + "table_ref": {"joined_table", "table_ref"}, + "joined_table": {"table_ref"}, + }, + expectedSCCs: [][]string{ + {"select_with_parens"}, + {"joined_table", "table_ref"}, + {"select_no_parens"}, + }, + recursiveRules: map[string]bool{ + "select_with_parens": true, + "select_no_parens": false, + "table_ref": true, + "joined_table": true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create dependency graph + g := NewDependencyGraph() + + // Add nodes + for ruleName := range tt.rules { + rule := &Rule{ + Name: ruleName, + Alternatives: []Alternative{}, + IsLexer: false, + } + // We need to add the node without building edges automatically + node := &GraphNode{ + RuleName: ruleName, + Alternatives: rule.Alternatives, + IsLexer: false, + SCCID: NoSCC, + SCCSize: 0, + IsRecursive: false, + } + g.Nodes[ruleName] = node + } + + // Set up edges manually + g.Edges = tt.rules + + // Compute SCCs + g.ComputeSCCs() + + // Verify number of SCCs + if len(g.SCCs) != len(tt.expectedSCCs) { + t.Errorf("Expected %d SCCs, got %d", len(tt.expectedSCCs), len(g.SCCs)) + t.Logf("SCCs found: %v", g.SCCs) + } + + // Verify each node's recursive status + for ruleName, expectedRecursive := range tt.recursiveRules { + node := g.GetNode(ruleName) + if node == nil { + t.Errorf("Node %s not found", ruleName) + continue + } + + if node.IsRecursive != expectedRecursive { + t.Errorf("Node %s: expected IsRecursive=%v, got %v (SCCID=%d, SCCSize=%d)", + ruleName, expectedRecursive, node.IsRecursive, node.SCCID, node.SCCSize) + } + } + + // Verify all nodes in same SCC have same SCCID + sccNodeMap := make(map[int][]string) + for ruleName, node := range g.Nodes { + if node.SCCID >= 0 { + sccNodeMap[node.SCCID] = append(sccNodeMap[node.SCCID], ruleName) + } + } + + // Log SCC information for debugging + t.Logf("SCCs detected:") + for sccID, nodes := range sccNodeMap { + t.Logf(" SCC %d: %v", sccID, nodes) + } + }) + } +} + +// TestSCCEdgeBuilding tests that edges are correctly built from grammar rules +func TestSCCEdgeBuilding(t *testing.T) { + // Create a simple grammar with references + g := NewDependencyGraph() + + // Add lexer rule (should not create edges) + lexerRule := &Rule{ + Name: "ID", + IsLexer: true, + Alternatives: []Alternative{ + { + Elements: []Element{ + {Value: LiteralValue{Text: "[a-zA-Z]+"}}, + }, + }, + }, + } + g.AddNode("ID", lexerRule) + + // Add parser rule with references + parserRule := &Rule{ + Name: "expr", + IsLexer: false, + Alternatives: []Alternative{ + { + Elements: []Element{ + {Value: ReferenceValue{Name: "expr"}}, + {Value: LiteralValue{Text: "+"}}, + {Value: ReferenceValue{Name: "term"}}, + }, + }, + { + Elements: []Element{ + {Value: ReferenceValue{Name: "term"}}, + }, + }, + }, + } + + // Add term rule + termRule := &Rule{ + Name: "term", + IsLexer: false, + Alternatives: []Alternative{ + { + Elements: []Element{ + {Value: ReferenceValue{Name: "ID"}}, // Reference to lexer + }, + }, + { + Elements: []Element{ + {Value: LiteralValue{Text: "123"}}, + }, + }, + }, + } + + // Need to add term first so it exists when expr references it + g.AddNode("term", termRule) + g.AddNode("expr", parserRule) + + // Build edges after adding all nodes + g.BuildEdges() + + // Verify edges + // expr should have edges to: expr (self), term + exprEdges := g.Edges["expr"] + if len(exprEdges) == 0 { + t.Error("expr should have edges") + } + + hasExprEdge := false + hasTermEdge := false + for _, edge := range exprEdges { + if edge == "expr" { + hasExprEdge = true + } + if edge == "term" { + hasTermEdge = true + } + } + + if !hasExprEdge { + t.Error("expr should have self-edge") + } + if !hasTermEdge { + t.Error("expr should have edge to term") + } + + // term should NOT have edge to ID (lexer rule) + termEdges := g.Edges["term"] + for _, edge := range termEdges { + if edge == "ID" { + t.Error("term should not have edge to lexer rule ID") + } + } + + // Compute SCCs and verify + g.ComputeSCCs() + + // expr should be recursive (self-loop) + exprNode := g.GetNode("expr") + if !exprNode.IsRecursive { + t.Error("expr should be marked as recursive due to self-loop") + } + + // term should not be recursive + termNode := g.GetNode("term") + if termNode.IsRecursive { + t.Error("term should not be marked as recursive") + } +} \ No newline at end of file diff --git a/tools/fuzzing/tests/postgresql_test.go b/tools/fuzzing/tests/postgresql_test.go index fa067a4..aaa8363 100644 --- a/tools/fuzzing/tests/postgresql_test.go +++ b/tools/fuzzing/tests/postgresql_test.go @@ -17,9 +17,9 @@ func getRepoRoot() string { return filepath.Join(filepath.Dir(filename), "..", "..", "..") } -func TestPostgreSQLSelectStmt(t *testing.T) { +func TestPostgreSQLRootStmt(t *testing.T) { repoRoot := getRepoRoot() - + // PostgreSQL grammar file paths lexerPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLLexer.g4") parserPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLParser.g4") @@ -33,24 +33,24 @@ func TestPostgreSQLSelectStmt(t *testing.T) { seed int64 }{ { - name: "Simple SELECT statements", - startRule: "selectstmt", + name: "Simple root", + startRule: "root", count: 3, - maxDepth: 5, + maxDepth: 10, optionalProb: 0.7, seed: 42, }, { - name: "Deep SELECT statements", - startRule: "selectstmt", + name: "Deep root", + startRule: "root", count: 2, maxDepth: 8, optionalProb: 0.5, seed: 123, }, { - name: "Minimal SELECT statements", - startRule: "selectstmt", + name: "Minimal root", + startRule: "root", count: 5, maxDepth: 3, optionalProb: 0.3, @@ -74,7 +74,7 @@ func TestPostgreSQLSelectStmt(t *testing.T) { } fmt.Printf("\n=== %s ===\n", tt.name) - fmt.Printf("Config: MaxDepth=%d, OptionalProb=%.1f, Count=%d, Seed=%d\n", + fmt.Printf("Config: MaxDepth=%d, OptionalProb=%.1f, Count=%d, Seed=%d\n", tt.maxDepth, tt.optionalProb, tt.count, tt.seed) fmt.Println() @@ -89,104 +89,3 @@ func TestPostgreSQLSelectStmt(t *testing.T) { }) } } - -func TestPostgreSQLExpressions(t *testing.T) { - repoRoot := getRepoRoot() - - // PostgreSQL grammar file paths - lexerPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLLexer.g4") - parserPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLParser.g4") - - cfg := &config.Config{ - GrammarFiles: []string{lexerPath, parserPath}, - StartRule: "a_expr", // PostgreSQL expression rule - Count: 5, - MaxDepth: 4, - OptionalProb: 0.6, - MaxQuantifier: 2, - MinQuantifier: 1, - QuantifierCount: 0, - OutputFormat: config.CompactOutput, - Seed: 789, - } - - fmt.Printf("\n=== PostgreSQL Expressions ===\n") - fmt.Printf("Generating %d expressions with max depth %d\n", cfg.Count, cfg.MaxDepth) - fmt.Println() - - gen := generator.New(cfg) - err := gen.Generate() - - if err != nil { - t.Errorf("Failed to generate PostgreSQL expressions: %v", err) - } else { - t.Logf("Successfully generated %d PostgreSQL expressions", cfg.Count) - } -} - -func TestPostgreSQLVerboseOutput(t *testing.T) { - repoRoot := getRepoRoot() - - // PostgreSQL grammar file paths - lexerPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLLexer.g4") - parserPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLParser.g4") - - cfg := &config.Config{ - GrammarFiles: []string{lexerPath, parserPath}, - StartRule: "selectstmt", - Count: 2, - MaxDepth: 4, - OptionalProb: 0.8, - MaxQuantifier: 2, - MinQuantifier: 1, - QuantifierCount: 0, - OutputFormat: config.VerboseOutput, // Show rule traversal - Seed: 999, - } - - fmt.Printf("\n=== PostgreSQL Verbose Output ===\n") - fmt.Printf("Generating with verbose output to show rule traversal\n") - fmt.Println() - - gen := generator.New(cfg) - err := gen.Generate() - - if err != nil { - t.Errorf("Failed to generate PostgreSQL statements with verbose output: %v", err) - } else { - t.Logf("Successfully generated PostgreSQL statements with verbose output") - } -} - -// Benchmark test for performance measurement -func BenchmarkPostgreSQLGeneration(b *testing.B) { - repoRoot := getRepoRoot() - - lexerPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLLexer.g4") - parserPath := filepath.Join(repoRoot, "postgresql", "PostgreSQLParser.g4") - - cfg := &config.Config{ - GrammarFiles: []string{lexerPath, parserPath}, - StartRule: "selectstmt", - Count: 1, - MaxDepth: 6, - OptionalProb: 0.5, - MaxQuantifier: 3, - MinQuantifier: 1, - QuantifierCount: 0, - OutputFormat: config.CompactOutput, - Seed: 42, - } - - gen := generator.New(cfg) - - // Reset the timer to exclude setup time - b.ResetTimer() - - for i := 0; i < b.N; i++ { - err := gen.Generate() - if err != nil { - b.Fatalf("Generation failed: %v", err) - } - } -} \ No newline at end of file