|
4 | 4 | package plugin6
|
5 | 5 |
|
6 | 6 | import (
|
| 7 | + "bytes" |
7 | 8 | "context"
|
8 | 9 | "errors"
|
9 | 10 | "fmt"
|
@@ -1478,6 +1479,140 @@ func (p *GRPCProvider) ConfigureStateStore(r providers.ConfigureStateStoreReques
|
1478 | 1479 | return resp
|
1479 | 1480 | }
|
1480 | 1481 |
|
| 1482 | +func (p *GRPCProvider) ReadStateBytes(r providers.ReadStateBytesRequest) (resp providers.ReadStateBytesResponse) { |
| 1483 | + logger.Trace("GRPCProvider.v6: ReadStateBytes") |
| 1484 | + |
| 1485 | + schema := p.GetProviderSchema() |
| 1486 | + if schema.Diagnostics.HasErrors() { |
| 1487 | + resp.Diagnostics = schema.Diagnostics |
| 1488 | + return resp |
| 1489 | + } |
| 1490 | + |
| 1491 | + if _, ok := schema.StateStores[r.TypeName]; !ok { |
| 1492 | + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("unknown state store type %q", r.TypeName)) |
| 1493 | + return resp |
| 1494 | + } |
| 1495 | + |
| 1496 | + protoReq := &proto6.ReadStateBytes_Request{ |
| 1497 | + TypeName: r.TypeName, |
| 1498 | + StateId: r.StateId, |
| 1499 | + } |
| 1500 | + |
| 1501 | + // Start the streaming RPC with a context. The context will be cancelled |
| 1502 | + // when this function returns, which will stop the stream if it is still |
| 1503 | + // running. |
| 1504 | + ctx, cancel := context.WithCancel(p.ctx) |
| 1505 | + defer cancel() |
| 1506 | + |
| 1507 | + client, err := p.client.ReadStateBytes(ctx, protoReq) |
| 1508 | + if err != nil { |
| 1509 | + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) |
| 1510 | + return resp |
| 1511 | + } |
| 1512 | + |
| 1513 | + var buf *bytes.Buffer |
| 1514 | + var expectedTotalLength int |
| 1515 | + for { |
| 1516 | + chunk, err := client.Recv() |
| 1517 | + if err == io.EOF { |
| 1518 | + // End of stream, we're done |
| 1519 | + break |
| 1520 | + } |
| 1521 | + if err != nil { |
| 1522 | + resp.Diagnostics = resp.Diagnostics.Append(err) |
| 1523 | + break |
| 1524 | + } |
| 1525 | + resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(chunk.Diagnostics)) |
| 1526 | + if resp.Diagnostics.HasErrors() { |
| 1527 | + // If we have errors, we stop processing and return early |
| 1528 | + break |
| 1529 | + } |
| 1530 | + |
| 1531 | + if expectedTotalLength == 0 { |
| 1532 | + expectedTotalLength = int(chunk.TotalLength) |
| 1533 | + } |
| 1534 | + logger.Trace("GRPCProvider.v6: ReadStateBytes: received chunk for range", chunk.Range) |
| 1535 | + |
| 1536 | + n, err := buf.Write(chunk.Bytes) |
| 1537 | + if err != nil { |
| 1538 | + resp.Diagnostics = resp.Diagnostics.Append(err) |
| 1539 | + break |
| 1540 | + } |
| 1541 | + logger.Trace("GRPCProvider.v6: ReadStateBytes: read bytes of a chunk", n) |
| 1542 | + } |
| 1543 | + |
| 1544 | + logger.Trace("GRPCProvider.v6: ReadStateBytes: received all chunks", buf.Len()) |
| 1545 | + if buf.Len() != expectedTotalLength { |
| 1546 | + err = fmt.Errorf("expected state file of total %d bytes, received %d bytes", |
| 1547 | + expectedTotalLength, buf.Len()) |
| 1548 | + resp.Diagnostics = resp.Diagnostics.Append(err) |
| 1549 | + return resp |
| 1550 | + } |
| 1551 | + resp.Bytes = buf.Bytes() |
| 1552 | + |
| 1553 | + return resp |
| 1554 | +} |
| 1555 | + |
| 1556 | +func (p *GRPCProvider) WriteStateBytes(r providers.WriteStateBytesRequest) (resp providers.WriteStateBytesResponse) { |
| 1557 | + logger.Trace("GRPCProvider.v6: WriteStateBytes") |
| 1558 | + |
| 1559 | + schema := p.GetProviderSchema() |
| 1560 | + if schema.Diagnostics.HasErrors() { |
| 1561 | + resp.Diagnostics = schema.Diagnostics |
| 1562 | + return resp |
| 1563 | + } |
| 1564 | + |
| 1565 | + if _, ok := schema.StateStores[r.TypeName]; !ok { |
| 1566 | + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("unknown state store type %q", r.TypeName)) |
| 1567 | + return resp |
| 1568 | + } |
| 1569 | + |
| 1570 | + // Start the streaming RPC with a context. The context will be cancelled |
| 1571 | + // when this function returns, which will stop the stream if it is still |
| 1572 | + // running. |
| 1573 | + ctx, cancel := context.WithCancel(p.ctx) |
| 1574 | + defer cancel() |
| 1575 | + |
| 1576 | + // TODO: Configurable chunk size |
| 1577 | + chunkSize := 4 * 1_000_000 // 4MB |
| 1578 | + |
| 1579 | + if len(r.Bytes) < chunkSize { |
| 1580 | + protoReq := &proto6.WriteStateBytes_RequestChunk{ |
| 1581 | + TypeName: r.TypeName, |
| 1582 | + StateId: r.StateId, |
| 1583 | + Bytes: r.Bytes, |
| 1584 | + TotalLength: int64(len(r.Bytes)), |
| 1585 | + Range: &proto6.StateRange{ |
| 1586 | + Start: 0, |
| 1587 | + End: int64(len(r.Bytes)), |
| 1588 | + }, |
| 1589 | + } |
| 1590 | + client, err := p.client.WriteStateBytes(ctx) |
| 1591 | + if err != nil { |
| 1592 | + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) |
| 1593 | + return resp |
| 1594 | + } |
| 1595 | + err = client.Send(protoReq) |
| 1596 | + if err != nil { |
| 1597 | + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) |
| 1598 | + return resp |
| 1599 | + } |
| 1600 | + protoResp, err := client.CloseAndRecv() |
| 1601 | + if err != nil { |
| 1602 | + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) |
| 1603 | + return resp |
| 1604 | + } |
| 1605 | + resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(protoResp.Diagnostics)) |
| 1606 | + if resp.Diagnostics.HasErrors() { |
| 1607 | + return resp |
| 1608 | + } |
| 1609 | + } |
| 1610 | + |
| 1611 | + // TODO: implement chunking for state files larger than chunkSize |
| 1612 | + |
| 1613 | + return resp |
| 1614 | +} |
| 1615 | + |
1481 | 1616 | func (p *GRPCProvider) GetStates(r providers.GetStatesRequest) (resp providers.GetStatesResponse) {
|
1482 | 1617 | logger.Trace("GRPCProvider.v6: GetStates")
|
1483 | 1618 |
|
|
0 commit comments