feat: add machine tag and inference timings #4577

Merged: 3 commits, merged on Jan 17, 2025
backend/backend.proto (3 additions, 1 deletion)

@@ -159,6 +159,8 @@ message Reply {
bytes message = 1;
int32 tokens = 2;
int32 prompt_tokens = 3;
double timing_prompt_processing = 4;
double timing_token_generation = 5;
}

message ModelOptions {
@@ -348,4 +350,4 @@ message StatusResponse {
message Message {
string role = 1;
string content = 2;
}
}
backend/cpp/llama/grpc-server.cpp (14 additions, 0 deletions)

@@ -2414,6 +2414,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);

if (result.result_json.contains("timings")) {
double timing_prompt_processing = result.result_json.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = result.result_json.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}

// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
@@ -2454,6 +2461,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
reply->set_prompt_tokens(tokens_evaluated);
reply->set_tokens(tokens_predicted);
reply->set_message(completion_text);

if (result.result_json.contains("timings")) {
double timing_prompt_processing = result.result_json.at("timings").value("prompt_ms", 0.0);
reply->set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = result.result_json.at("timings").value("predicted_ms", 0.0);
reply->set_timing_token_generation(timing_token_generation);
}
}
else
{
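The two values read above come straight from the llama.cpp server's per-request "timings" object; only prompt_ms and predicted_ms (reported in milliseconds) are consumed. As a rough illustration, assuming the generated gRPC types are imported as pb and this sits next to the TokenUsage struct extended in core/backend/llm.go, the mapping on the Go side amounts to:

    // Illustrative sketch, not part of the PR: copy the new Reply timing fields
    // into the TokenUsage accumulator. Package and import names are assumptions.
    func usageFromReply(reply *pb.Reply, usage *TokenUsage) {
        usage.Prompt = int(reply.PromptTokens)                      // tokens_evaluated
        usage.Completion = int(reply.Tokens)                        // tokens_predicted
        usage.TimingPromptProcessing = reply.TimingPromptProcessing // "prompt_ms", in milliseconds
        usage.TimingTokenGeneration = reply.TimingTokenGeneration   // "predicted_ms", in milliseconds
    }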
core/backend/llm.go (10 additions, 2 deletions)

@@ -27,8 +27,10 @@ type LLMResponse struct {
}

type TokenUsage struct {
Prompt int
Completion int
Prompt int
Completion int
TimingPromptProcessing float64
TimingTokenGeneration float64
}

func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
@@ -123,6 +125,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im

tokenUsage.Prompt = int(reply.PromptTokens)
tokenUsage.Completion = int(reply.Tokens)
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

for len(partialRune) > 0 {
r, size := utf8.DecodeRune(partialRune)
@@ -157,6 +161,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
if tokenUsage.Completion == 0 {
tokenUsage.Completion = int(reply.Tokens)
}

tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
core/cli/run.go (2 additions, 0 deletions)

@@ -70,6 +70,7 @@ type RunCMD struct {
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
MachineTag string `env:"LOCALAI_MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
}

@@ -107,6 +108,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
config.WithLoadToMemory(r.LoadToMemory),
config.WithMachineTag(r.MachineTag),
}

if r.DisableMetricsEndpoint {
core/config/application_config.go (8 additions, 0 deletions)

@@ -65,6 +65,8 @@ type ApplicationConfig struct {
ModelsURL []string

WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration

MachineTag string
}

type AppOption func(*ApplicationConfig)
@@ -94,6 +96,12 @@ func WithModelPath(path string) AppOption {
}
}

func WithMachineTag(tag string) AppOption {
return func(o *ApplicationConfig) {
o.MachineTag = tag
}
}

func WithCors(b bool) AppOption {
return func(o *ApplicationConfig) {
o.CORS = b
core/http/app.go (8 additions, 0 deletions)

@@ -89,6 +89,14 @@ func API(application *application.Application) (*fiber.App, error) {

router.Use(middleware.StripPathPrefix())

if application.ApplicationConfig().MachineTag != "" {
router.Use(func(c *fiber.Ctx) error {
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)

return c.Next()
})
}

router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
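With the middleware above, every HTTP response from the server carries the configured tag, so a client (or a P2P load balancer) can tell which machine answered. A minimal sketch of checking it, assuming an instance started with LOCALAI_MACHINE_TAG set and listening on localhost:8080 (both assumptions, not values from the PR):

    // Illustrative sketch, not part of the PR: read the Machine-Tag response
    // header from any endpoint of a tagged LocalAI instance.
    package main

    import (
        "fmt"
        "net/http"
    )

    func main() {
        resp, err := http.Get("http://localhost:8080/readyz") // endpoint and port are assumptions
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        fmt.Println("answered by machine:", resp.Header.Get("Machine-Tag"))
    }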
core/http/endpoints/localai/tts.go (0 additions, 1 deletion)

@@ -24,7 +24,6 @@ import (
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {

input := new(schema.TTSRequest)

// Get input data from the request body
core/http/endpoints/localai/vad.go (0 additions, 1 deletion)

@@ -19,7 +19,6 @@ import (
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {

input := new(schema.VADRequest)

// Get input data from the request body
core/http/endpoints/openai/chat.go (39 additions, 20 deletions)

@@ -30,7 +30,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
var id, textContentToReturn string
var created int

process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
@@ -40,26 +40,32 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
}
responses <- initialMessage

ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
[Inline review thread on this change]

Owner: Any specific reason to put this under a feature flag? Since the data is already part of the response there is no extra computation penalty as far as I can see, and as it is just statistical data it would probably be safe to always return it as part of the response.

Contributor (author): I thought that since LocalAI is a drop-in replacement for the OpenAI API, it needs to keep strictly the same response models, since some parsers can fail on extra fields.

Contributor (author, @mintyleaf, Jan 17, 2025): If you think it is safe, I can just remove the feature flag so this is always included, or it can be made opt-out as well.

Owner: I see your point. I'm fine keeping it behind a flag, but we would need to update the docs in this case, as otherwise it would go unnoticed (it does not surface in the CLI help either).

usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}

resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
Usage: usage,
}

responses <- resp
return true
})
close(responses)
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
@@ -90,18 +96,23 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
log.Error().Err(err).Msg("error handling question")
return
}
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}

resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
Usage: usage,
}

responses <- resp
@@ -170,6 +181,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
}
c.Set("X-Correlation-ID", correlationID)

// Opt-in extra usage flag
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""

modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -319,9 +333,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
responses := make(chan schema.OpenAIResponse)

if !shouldUseFn {
go process(predInput, input, config, ml, responses)
go process(predInput, input, config, ml, responses, extraUsage)
} else {
go processTools(noActionName, predInput, input, config, ml, responses)
go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
}

c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@@ -449,18 +463,23 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
if err != nil {
return err
}
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}

resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
Usage: usage,
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)
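Tying the pieces together: the timings are returned only when the request opts in by sending the LocalAI-Extra-Usage header with any non-empty value, as discussed in the review thread above. A rough client sketch follows; the base URL, port, model name, and the JSON keys under which the timing fields appear are assumptions, since the schema.OpenAIUsage changes are not part of this diff.

    // Illustrative sketch, not part of the PR: opt into the extra usage timings
    // on the chat completions endpoint. URL, port and model name are assumptions.
    package main

    import (
        "bytes"
        "fmt"
        "io"
        "net/http"
    )

    func main() {
        body := bytes.NewBufferString(`{"model":"some-model","messages":[{"role":"user","content":"hi"}]}`)
        req, err := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", body)
        if err != nil {
            panic(err)
        }
        req.Header.Set("Content-Type", "application/json")
        req.Header.Set("LocalAI-Extra-Usage", "1") // any non-empty value enables the timing fields

        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        raw, _ := io.ReadAll(resp.Body)
        fmt.Println(string(raw)) // the usage object now includes prompt-processing and token-generation timings
    }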
core/http/endpoints/openai/completion.go (29 additions, 15 deletions)

@@ -30,8 +30,17 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
id := uuid.New().String()
created := int(time.Now().Unix())

process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
@@ -43,11 +52,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
Usage: usage,
}
log.Debug().Msgf("Sending goroutine: %s", s)

@@ -60,6 +65,10 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
return func(c *fiber.Ctx) error {
// Add Correlation
c.Set("X-Correlation-ID", id)

// Opt-in extra usage flag
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""

modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -113,7 +122,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e

responses := make(chan schema.OpenAIResponse)

go process(predInput, input, config, ml, responses)
go process(predInput, input, config, ml, responses, extraUsage)

c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

@@ -170,23 +179,28 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
return err
}

totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing

result = append(result, r...)
}
usage := schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
}
if extraUsage {
usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
}

resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
Usage: usage,
}

jsonResult, _ := json.Marshal(resp)