Skip to content

Commit

Permalink
Datasource cleanup: introduce some types and avoid pass-by-context (#…
Browse files Browse the repository at this point in the history
…5317)

* Datasource cleanup: introduce some types and avoid pass-by-context

* Fix lint
  • Loading branch information
evankanderson authored Jan 17, 2025
1 parent dc32266 commit 1a425a6
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 104 deletions.
25 changes: 8 additions & 17 deletions internal/datasources/rest/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/mindersec/minder/internal/util/schemaupdate"
"github.com/mindersec/minder/internal/util/schemavalidate"
minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

const (
Expand Down Expand Up @@ -69,7 +70,7 @@ func newHandlerFromDef(def *minderv1.RestDataSource_Def) (*restHandler, error) {
}, nil
}

func (h *restHandler) GetArgsSchema() any {
func (h *restHandler) GetArgsSchema() *structpb.Struct {
return h.rawInputSchema
}

Expand All @@ -86,28 +87,18 @@ func (h *restHandler) ValidateArgs(args any) error {
return schemavalidate.ValidateAgainstSchema(h.inputSchema, mapobj)
}

func (h *restHandler) ValidateUpdate(obj any) error {
if obj == nil {
func (h *restHandler) ValidateUpdate(argsSchema *structpb.Struct) error {
if argsSchema == nil {
return errors.New("update schema cannot be nil")
}

switch castedobj := obj.(type) {
case *structpb.Struct:
if _, err := schemavalidate.CompileSchemaFromPB(castedobj); err != nil {
return fmt.Errorf("update validation failed due to invalid schema: %w", err)
}
return schemaupdate.ValidateSchemaUpdate(h.rawInputSchema, castedobj)
case map[string]any:
if _, err := schemavalidate.CompileSchemaFromMap(castedobj); err != nil {
return fmt.Errorf("update validation failed due to invalid schema: %w", err)
}
return schemaupdate.ValidateSchemaUpdateMap(h.rawInputSchema.AsMap(), castedobj)
default:
return errors.New("invalid type")
if _, err := schemavalidate.CompileSchemaFromPB(argsSchema); err != nil {
return fmt.Errorf("update validation failed due to invalid schema: %w", err)
}
return schemaupdate.ValidateSchemaUpdate(h.rawInputSchema, argsSchema)
}

func (h *restHandler) Call(ctx context.Context, args any) (any, error) {
func (h *restHandler) Call(ctx context.Context, _ *interfaces.Result, args any) (any, error) {
argsMap, ok := args.(map[string]any)
if !ok {
return nil, errors.New("args is not a map")
Expand Down
32 changes: 2 additions & 30 deletions internal/datasources/rest/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func Test_restHandler_Call(t *testing.T) {
headers: tt.fields.headers,
parse: tt.fields.parse,
}
got, err := h.Call(context.Background(), tt.args.args)
got, err := h.Call(context.Background(), nil, tt.args.args)
if tt.wantErr {
assert.Error(t, err)
} else {
Expand Down Expand Up @@ -367,7 +367,7 @@ func Test_restHandler_ValidateUpdate(t *testing.T) {
t.Parallel()

type args struct {
updateSchema any
updateSchema *structpb.Struct
}
tests := []struct {
name string
Expand Down Expand Up @@ -408,34 +408,6 @@ func Test_restHandler_ValidateUpdate(t *testing.T) {
},
wantErr: false,
},
{
name: "Valid map[string]any",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{"key": map[string]any{"type": "string"}},
},
args: args{
updateSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"key": map[string]any{"type": "string"},
"new_key": map[string]any{"type": "number"},
},
},
},
wantErr: false,
},
{
name: "Invalid type",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{"key": map[string]any{"type": "string"}},
},
args: args{
updateSchema: "invalid_type",
},
wantErr: true,
},
{
name: "nil update schema",
inputSchema: map[string]any{
Expand Down
18 changes: 7 additions & 11 deletions internal/datasources/structured/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ import (

"github.com/go-git/go-billy/v5"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/structpb"

minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
v1datasources "github.com/mindersec/minder/pkg/datasources/v1"
"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

const (
Expand Down Expand Up @@ -148,21 +150,15 @@ func parseFile(f billy.File) (any, error) {
}

// Call parses the structured data from the billy filesystem in the context
func (sh *structHandler) Call(ctx context.Context, _ any) (any, error) {
var ctxData v1datasources.Context
var ok bool
if ctxData, ok = ctx.Value(v1datasources.ContextKey{}).(v1datasources.Context); !ok {
return nil, fmt.Errorf("unable to read execution context")
}

if ctxData.Ingest == nil || ctxData.Ingest.Fs == nil {
func (sh *structHandler) Call(_ context.Context, ingest *interfaces.Result, _ any) (any, error) {
if ingest == nil || ingest.Fs == nil {
return nil, fmt.Errorf("filesystem not found in execution context")
}

return parseFileAlternatives(ctxData.Ingest.Fs, sh.Path.GetFileName(), sh.Path.GetAlternatives())
return parseFileAlternatives(ingest.Fs, sh.Path.GetFileName(), sh.Path.GetAlternatives())
}

func (*structHandler) GetArgsSchema() any {
func (*structHandler) GetArgsSchema() *structpb.Struct {
return nil
}

Expand All @@ -172,6 +168,6 @@ func (_ *structHandler) ValidateArgs(any) error {
}

// ValidateUpdate
func (_ *structHandler) ValidateUpdate(any) error {
func (_ *structHandler) ValidateUpdate(*structpb.Struct) error {
return nil
}
35 changes: 12 additions & 23 deletions internal/datasources/structured/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/require"

minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
v1datasources "github.com/mindersec/minder/pkg/datasources/v1"
"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

Expand Down Expand Up @@ -166,25 +165,19 @@ func TestNew(t *testing.T) {
func TestCall(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
buildContext func(t *testing.T) context.Context
def *minderv1.StructDataSource_Def
mustErr bool
name string
ingest func(t *testing.T) *interfaces.Result
def *minderv1.StructDataSource_Def
mustErr bool
}{
{
"success",
func(t *testing.T) context.Context {
func(t *testing.T) *interfaces.Result {
t.Helper()
fs := memfs.New()
writeFSFile(t, fs, "./test1.json", []byte("{ \"a\": \"b\"}"))

return context.WithValue(
context.Background(),
v1datasources.ContextKey{},
v1datasources.Context{
Ingest: &interfaces.Result{Fs: fs},
},
)
return &interfaces.Result{Fs: fs}
},
&minderv1.StructDataSource_Def{
Path: &minderv1.StructDataSource_Def_Path{
Expand All @@ -195,31 +188,27 @@ func TestCall(t *testing.T) {
},
{
"no-datasource-context",
func(t *testing.T) context.Context {
func(t *testing.T) *interfaces.Result {
t.Helper()
return context.Background()
return nil
},
&minderv1.StructDataSource_Def{},
true,
},
{"ctx-no-fs",
func(t *testing.T) context.Context {
func(t *testing.T) *interfaces.Result {
t.Helper()
return context.WithValue(
context.Background(),
v1datasources.ContextKey{},
v1datasources.Context{},
)
return &interfaces.Result{}
},
&minderv1.StructDataSource_Def{},
true},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := tc.buildContext(t)
ingest := tc.ingest(t)
handler, err := newHandlerFromDef(tc.def)
require.NoError(t, err)
_, err = handler.Call(ctx, []string{})
_, err = handler.Call(context.Background(), ingest, []string{})
if tc.mustErr {
require.Error(t, err)
return
Expand Down
13 changes: 2 additions & 11 deletions internal/engine/eval/rego/datasources.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package rego

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -48,7 +47,7 @@ func buildFromDataSource(
Name: k,
Decl: types.NewFunction(types.Args(types.A), types.A),
},
func(_ rego.BuiltinContext, obj *ast.Term) (*ast.Term, error) {
func(bctx rego.BuiltinContext, obj *ast.Term) (*ast.Term, error) {
// Convert the AST value back to a Go interface{}
jsonObj, err := ast.JSON(obj.Value)
if err != nil {
Expand All @@ -59,15 +58,7 @@ func buildFromDataSource(
return nil, err
}

// Call the data source function
ctx := context.WithValue(
context.Background(),
v1datasources.ContextKey{},
v1datasources.Context{
Ingest: res,
},
)
ret, err := dsf.Call(ctx, jsonObj)
ret, err := dsf.Call(bctx.Context, res, jsonObj)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/engine/eval/rego/rego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ allow {
emptyPol := map[string]any{}

// Matches
fdsf.EXPECT().Call(gomock.Any(), gomock.Any()).Return("foo", nil)
fdsf.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Return("foo", nil)
_, err = e.Eval(context.Background(), emptyPol, nil, &interfaces.Result{
Object: map[string]any{
"data": "foo",
Expand All @@ -531,7 +531,7 @@ allow {
require.NoError(t, err, "could not evaluate")

// Doesn't match
fdsf.EXPECT().Call(gomock.Any(), gomock.Any()).Return("bar", nil)
fdsf.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Return("bar", nil)
_, err = e.Eval(context.Background(), emptyPol, nil, &interfaces.Result{
Object: map[string]any{
"data": "bar",
Expand Down
8 changes: 5 additions & 3 deletions pkg/datasources/v1/datasources.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package v1
import (
"context"

"google.golang.org/protobuf/types/known/structpb"

"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

Expand Down Expand Up @@ -36,14 +38,14 @@ type DataSourceFuncDef interface {
// ValidateUpdate validates the update to the data source.
// The data source implementation should respect the update and return an error
// if the update is invalid.
ValidateUpdate(obj any) error
ValidateUpdate(obj *structpb.Struct) error
// Call calls the function with the given arguments.
// It is the responsibility of the data source implementation to handle the call.
// It is also the responsibility of the caller to validate the arguments
// before calling the function.
Call(ctx context.Context, args any) (any, error)
Call(ctx context.Context, ingest *interfaces.Result, args any) (any, error)
// GetArgsSchema returns the schema of the arguments.
GetArgsSchema() any
GetArgsSchema() *structpb.Struct
}

// DataSource is the interface that a data source must implement.
Expand Down
16 changes: 9 additions & 7 deletions pkg/datasources/v1/mock/datasources.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1a425a6

Please sign in to comment.