Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 4 additions & 29 deletions directed.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package graph

import (
"errors"
"fmt"
)

Expand Down Expand Up @@ -76,28 +75,14 @@ func (d *directed[K, T]) RemoveVertex(hash K) error {
}

func (d *directed[K, T]) AddEdge(sourceHash, targetHash K, options ...func(*EdgeProperties)) error {
_, _, err := d.store.Vertex(sourceHash)
if err != nil {
return fmt.Errorf("source vertex %v: %w", sourceHash, err)
}

_, _, err = d.store.Vertex(targetHash)
if err != nil {
return fmt.Errorf("target vertex %v: %w", targetHash, err)
}

if _, err := d.Edge(sourceHash, targetHash); !errors.Is(err, ErrEdgeNotFound) {
return ErrEdgeAlreadyExists
}

// If the user opted in to preventing cycles, run a cycle check.
if d.traits.PreventCycles {
createsCycle, err := d.createsCycle(sourceHash, targetHash)
if err != nil {
return fmt.Errorf("check for cycles: %w", err)
}
if createsCycle {
return ErrEdgeCreatesCycle
return &EdgeCausesCycleError[K]{Source: sourceHash, Target: targetHash}
}
}

Expand Down Expand Up @@ -176,10 +161,6 @@ func (d *directed[K, T]) UpdateEdge(source, target K, options ...func(properties
}

func (d *directed[K, T]) RemoveEdge(source, target K) error {
if _, err := d.Edge(source, target); err != nil {
return err
}

if err := d.store.RemoveEdge(source, target); err != nil {
return fmt.Errorf("failed to remove edge from %v to %v: %w", source, target, err)
}
Expand Down Expand Up @@ -273,17 +254,11 @@ func (d *directed[K, T]) Order() (int, error) {
}

func (d *directed[K, T]) Size() (int, error) {
size := 0
outEdges, err := d.AdjacencyMap()
edges, err := d.store.ListEdges()
if err != nil {
return 0, fmt.Errorf("failed to get adjacency map: %w", err)
return 0, fmt.Errorf("failed to list edges: %w", err)
}

for _, outEdges := range outEdges {
size += len(outEdges)
}

return size, nil
return len(edges), nil
}

func (d *directed[K, T]) edgesAreEqual(a, b Edge[T]) bool {
Expand Down
6 changes: 4 additions & 2 deletions directed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ func TestDirected_RemoveEdge(t *testing.T) {
removeEdges: []Edge[int]{
{Source: 2, Target: 3},
},
expectedError: ErrEdgeNotFound,
// Expect no error because memoryStore doesn't error
},
}

Expand All @@ -909,7 +909,7 @@ func TestDirected_RemoveEdge(t *testing.T) {
}
// After removing the edge, verify that it can't be retrieved using
// Edge anymore.
if _, err := graph.Edge(removeEdge.Source, removeEdge.Target); err != ErrEdgeNotFound {
if _, err := graph.Edge(removeEdge.Source, removeEdge.Target); !errors.Is(err, ErrEdgeNotFound) {
t.Fatalf("%s: error expectancy doesn't match: expected %v, got %v", name, ErrEdgeNotFound, err)
}
}
Expand Down Expand Up @@ -1267,6 +1267,8 @@ func TestDirected_addEdge(t *testing.T) {
graph := newDirected(IntHash, &Traits{}, newMemoryStore[int, int]())

for _, edge := range test.edges {
_ = graph.AddVertex(edge.Source)
_ = graph.AddVertex(edge.Target)
sourceHash := graph.hash(edge.Source)
TargetHash := graph.hash(edge.Target)
err := graph.addEdge(sourceHash, TargetHash, edge)
Expand Down
74 changes: 74 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package graph

import (
"errors"
"fmt"
)

type (
VertexAlreadyExistsError[K comparable, T any] struct {
Key K
ExistingValue T
}

VertexNotFoundError[K comparable] struct {
Key K
}

EdgeAlreadyExistsError[K comparable] struct {
Source, Target K
}

EdgeNotFoundError[K comparable] struct {
Source, Target K
}

VertexHasEdgesError[K comparable] struct {
Key K
Count int
}

EdgeCausesCycleError[K comparable] struct {
Source, Target K
}
)

func (e *VertexAlreadyExistsError[K, T]) Error() string {
return fmt.Sprintf("vertex %v already exists with value %v", e.Key, e.ExistingValue)
}

func (e *VertexNotFoundError[K]) Error() string {
return fmt.Sprintf("vertex %v not found", e.Key)
}

func (e *EdgeAlreadyExistsError[K]) Error() string {
return fmt.Sprintf("edge %v - %v already exists", e.Source, e.Target)
}

func (e *EdgeNotFoundError[K]) Error() string {
return fmt.Sprintf("edge %v - %v not found", e.Source, e.Target)
}

func (e *VertexHasEdgesError[K]) Error() string {
return fmt.Sprintf("vertex %v has %d edges", e.Key, e.Count)
}

func (e *EdgeCausesCycleError[K]) Error() string {
return fmt.Sprintf("edge %v - %v would cause a cycle", e.Source, e.Target)
}

var (
ErrVertexNotFound = errors.New("vertex not found")
ErrVertexAlreadyExists = errors.New("vertex already exists")
ErrEdgeNotFound = errors.New("edge not found")
ErrEdgeAlreadyExists = errors.New("edge already exists")
ErrEdgeCreatesCycle = errors.New("edge would create a cycle")
ErrVertexHasEdges = errors.New("vertex has edges")
)

func (e *VertexAlreadyExistsError[K, T]) Unwrap() error { return ErrVertexAlreadyExists }
func (e *VertexNotFoundError[K]) Unwrap() error { return ErrVertexNotFound }
func (e *EdgeAlreadyExistsError[K]) Unwrap() error { return ErrEdgeAlreadyExists }
func (e *EdgeNotFoundError[K]) Unwrap() error { return ErrEdgeNotFound }
func (e *VertexHasEdgesError[K]) Unwrap() error { return ErrVertexHasEdges }
func (e *EdgeCausesCycleError[K]) Unwrap() error { return ErrEdgeCreatesCycle }
11 changes: 0 additions & 11 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,6 @@
// For detailed usage examples, take a look at the README.
package graph

import "errors"

var (
ErrVertexNotFound = errors.New("vertex not found")
ErrVertexAlreadyExists = errors.New("vertex already exists")
ErrEdgeNotFound = errors.New("edge not found")
ErrEdgeAlreadyExists = errors.New("edge already exists")
ErrEdgeCreatesCycle = errors.New("edge would create a cycle")
ErrVertexHasEdges = errors.New("vertex has edges")
)

// Graph represents a generic graph data structure consisting of vertices of
// type T identified by a hash of type K.
type Graph[K comparable, T any] interface {
Expand Down
4 changes: 2 additions & 2 deletions paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ var ErrTargetNotReachable = errors.New("target vertex not reachable from source"
// of the source vertex. In order to determine this, CreatesCycle runs a DFS.
func CreatesCycle[K comparable, T any](g Graph[K, T], source, target K) (bool, error) {
if _, err := g.Vertex(source); err != nil {
return false, fmt.Errorf("could not get vertex with hash %v: %w", source, err)
return false, fmt.Errorf("could not get source vertex: %w", err)
}

if _, err := g.Vertex(target); err != nil {
return false, fmt.Errorf("could not get vertex with hash %v: %w", target, err)
return false, fmt.Errorf("could not get target vertex: %w", err)
}

if source == target {
Expand Down
67 changes: 49 additions & 18 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ func (s *memoryStore[K, T]) AddVertex(k K, t T, p VertexProperties) error {
s.lock.Lock()
defer s.lock.Unlock()

if _, ok := s.vertices[k]; ok {
return ErrVertexAlreadyExists
if existing, ok := s.vertices[k]; ok {
return &VertexAlreadyExistsError[K, T]{
Key: k,
ExistingValue: existing,
}
}

s.vertices[k] = t
Expand Down Expand Up @@ -120,10 +123,15 @@ func (s *memoryStore[K, T]) VertexCount() (int, error) {
func (s *memoryStore[K, T]) Vertex(k K) (T, VertexProperties, error) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.vertexWithLock(k)
}

// vertexWithLock returns the vertex and vertex properties - the caller must be holding at least a
// read-level lock.
func (s *memoryStore[K, T]) vertexWithLock(k K) (T, VertexProperties, error) {
v, ok := s.vertices[k]
if !ok {
return v, VertexProperties{}, ErrVertexNotFound
return v, VertexProperties{}, &VertexNotFoundError[K]{Key: k}
}

p := s.vertexProperties[k]
Expand All @@ -136,19 +144,19 @@ func (s *memoryStore[K, T]) RemoveVertex(k K) error {
defer s.lock.RUnlock()

if _, ok := s.vertices[k]; !ok {
return ErrVertexNotFound
return &VertexNotFoundError[K]{Key: k}
}

if edges, ok := s.inEdges[k]; ok {
if len(edges) > 0 {
return ErrVertexHasEdges
if count := len(edges); count > 0 {
return &VertexHasEdgesError[K]{Key: k, Count: count}
}
delete(s.inEdges, k)
}

if edges, ok := s.outEdges[k]; ok {
if len(edges) > 0 {
return ErrVertexHasEdges
if count := len(edges); count > 0 {
return &VertexHasEdgesError[K]{Key: k, Count: count}
}
delete(s.outEdges, k)
}
Expand All @@ -163,29 +171,45 @@ func (s *memoryStore[K, T]) AddEdge(sourceHash, targetHash K, edge Edge[K]) erro
s.lock.Lock()
defer s.lock.Unlock()

if _, _, err := s.vertexWithLock(sourceHash); err != nil {
return fmt.Errorf("could not get source vertex: %w", &VertexNotFoundError[K]{Key: sourceHash})
}

if _, ok := s.outEdges[sourceHash]; !ok {
s.outEdges[sourceHash] = make(map[K]Edge[K])
}

if _, ok := s.outEdges[sourceHash][targetHash]; ok {
return &EdgeAlreadyExistsError[K]{Source: sourceHash, Target: targetHash}
}

s.outEdges[sourceHash][targetHash] = edge

if _, _, err := s.vertexWithLock(targetHash); err != nil {
return fmt.Errorf("could not get target vertex: %w", &VertexNotFoundError[K]{Key: targetHash})
}

if _, ok := s.inEdges[targetHash]; !ok {
s.inEdges[targetHash] = make(map[K]Edge[K])
}

if _, ok := s.inEdges[targetHash][sourceHash]; ok {
return &EdgeAlreadyExistsError[K]{Source: sourceHash, Target: targetHash}
}

s.inEdges[targetHash][sourceHash] = edge

return nil
}

func (s *memoryStore[K, T]) UpdateEdge(sourceHash, targetHash K, edge Edge[K]) error {
if _, err := s.Edge(sourceHash, targetHash); err != nil {
return err
}

s.lock.Lock()
defer s.lock.Unlock()

if _, err := s.edgeWithLock(sourceHash, targetHash); err != nil {
return err
}

s.outEdges[sourceHash][targetHash] = edge
s.inEdges[targetHash][sourceHash] = edge

Expand All @@ -204,15 +228,19 @@ func (s *memoryStore[K, T]) RemoveEdge(sourceHash, targetHash K) error {
func (s *memoryStore[K, T]) Edge(sourceHash, targetHash K) (Edge[K], error) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.edgeWithLock(sourceHash, targetHash)
}

// edgeWithLock returns the edge - the caller must be holding at least a read-level lock.
func (s *memoryStore[K, T]) edgeWithLock(sourceHash, targetHash K) (Edge[K], error) {
sourceEdges, ok := s.outEdges[sourceHash]
if !ok {
return Edge[K]{}, ErrEdgeNotFound
return Edge[K]{}, &EdgeNotFoundError[K]{Source: sourceHash, Target: targetHash}
}

edge, ok := sourceEdges[targetHash]
if !ok {
return Edge[K]{}, ErrEdgeNotFound
return Edge[K]{}, &EdgeNotFoundError[K]{Source: sourceHash, Target: targetHash}
}

return edge, nil
Expand All @@ -237,12 +265,15 @@ func (s *memoryStore[K, T]) ListEdges() ([]Edge[K], error) {
// Because CreatesCycle doesn't need to modify the PredecessorMap, we can use
// inEdges instead to compute the same thing without creating any copies.
func (s *memoryStore[K, T]) CreatesCycle(source, target K) (bool, error) {
if _, _, err := s.Vertex(source); err != nil {
return false, fmt.Errorf("could not get vertex with hash %v: %w", source, err)
s.lock.RLock()
defer s.lock.RUnlock()

if _, _, err := s.vertexWithLock(source); err != nil {
return false, fmt.Errorf("could not get source vertex: %w", err)
}

if _, _, err := s.Vertex(target); err != nil {
return false, fmt.Errorf("could not get vertex with hash %v: %w", target, err)
if _, _, err := s.vertexWithLock(target); err != nil {
return false, fmt.Errorf("could not get target vertex: %w", err)
}

if source == target {
Expand Down
Loading