|
|
@@ -32,14 +32,14 @@ func (c *ApiController) ResponseErrorStream(errorText string) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) { |
|
|
|
func getModelProviderFromContext(owner string, name string) (*object.Provider, model.ModelProvider, error) { |
|
|
|
var providerName string |
|
|
|
if name != "" { |
|
|
|
providerName = name |
|
|
|
} else { |
|
|
|
store, err := object.GetDefaultStore(owner) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
return nil, nil, err |
|
|
|
} |
|
|
|
|
|
|
|
if store != nil && store.ModelProvider != "" { |
|
|
@@ -57,28 +57,28 @@ func getModelProviderFromContext(owner string, name string) (model.ModelProvider |
|
|
|
} |
|
|
|
|
|
|
|
if provider == nil && err == nil { |
|
|
|
return nil, fmt.Errorf("The model provider: %s is not found", providerName) |
|
|
|
return nil, nil, fmt.Errorf("The model provider: %s is not found", providerName) |
|
|
|
} |
|
|
|
if provider.Category != "Model" || provider.ClientSecret == "" { |
|
|
|
return nil, fmt.Errorf("The model provider: %s is invalid", providerName) |
|
|
|
return nil, nil, fmt.Errorf("The model provider: %s is invalid", providerName) |
|
|
|
} |
|
|
|
|
|
|
|
providerObj, err := provider.GetModelProvider() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
return nil, nil, err |
|
|
|
} |
|
|
|
|
|
|
|
return providerObj, err |
|
|
|
return provider, providerObj, err |
|
|
|
} |
|
|
|
|
|
|
|
func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) { |
|
|
|
func getEmbeddingProviderFromContext(owner string, name string) (*object.Provider, embedding.EmbeddingProvider, error) { |
|
|
|
var providerName string |
|
|
|
if name != "" { |
|
|
|
providerName = name |
|
|
|
} else { |
|
|
|
store, err := object.GetDefaultStore(owner) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
return nil, nil, err |
|
|
|
} |
|
|
|
|
|
|
|
if store != nil && store.EmbeddingProvider != "" { |
|
|
@@ -96,16 +96,16 @@ func getEmbeddingProviderFromContext(owner string, name string) (embedding.Embed |
|
|
|
} |
|
|
|
|
|
|
|
if provider == nil && err == nil { |
|
|
|
return nil, fmt.Errorf("The embedding provider: %s is not found", providerName) |
|
|
|
return nil, nil, fmt.Errorf("The embedding provider: %s is not found", providerName) |
|
|
|
} |
|
|
|
if provider.Category != "Embedding" || provider.ClientSecret == "" { |
|
|
|
return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName) |
|
|
|
return nil, nil, fmt.Errorf("The embedding provider: %s is invalid", providerName) |
|
|
|
} |
|
|
|
|
|
|
|
providerObj, err := provider.GetEmbeddingProvider() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
return nil, nil, err |
|
|
|
} |
|
|
|
|
|
|
|
return providerObj, err |
|
|
|
return provider, providerObj, err |
|
|
|
} |